Skip to content

cross_correlation

Core cross-correlation methods for single images/templates and for stacks of images/templates.

handle_correlation_mode(cross_correlation, out_shape, mode)

Handle cropping for cross correlation mode.

NOTE: 'full' mode is not implemented.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cross_correlation` | `Tensor` | The cross correlation result. | *required* |
| `out_shape` | `tuple[int, ...]` | The desired shape of the output. | *required* |
| `mode` | `Literal['valid', 'same']` | The mode of the cross correlation. Either `'valid'` or `'same'`. See `numpy.correlate` for more details. | *required* |
Source code in `src/leopard_em/utils/cross_correlation.py` (lines 8–40):
def handle_correlation_mode(
    cross_correlation: torch.Tensor,
    out_shape: tuple[int, ...],
    mode: Literal["valid", "same"],
) -> torch.Tensor:
    """Handle cropping for cross correlation mode.

    NOTE: 'full' mode is not implemented.

    Parameters
    ----------
    cross_correlation : torch.Tensor
        The cross correlation result.
    out_shape : tuple[int, ...]
        The desired shape of the output. For ``mode="valid"`` each entry is
        the number of elements kept (starting at index 0) along the
        corresponding leading dimension of ``cross_correlation``; any
        trailing dimensions not covered by ``out_shape`` are left untouched.
    mode : Literal["valid", "same"]
        The mode of the cross correlation. Either 'valid' or 'same'. See
        [numpy.convolve](https://numpy.org/doc/stable/reference/generated/numpy.convolve.html)
        (which ``numpy.correlate`` refers to) for the definition of each mode.

    Returns
    -------
    torch.Tensor
        For 'valid', a slice of ``cross_correlation`` cropped to
        ``out_shape``; for 'same', the input tensor itself, unchanged.

    Raises
    ------
    NotImplementedError
        If ``mode`` is 'full'.
    ValueError
        If ``mode`` is any other unrecognized value.
    """
    if mode == "valid":
        # Keep the first ``out_shape[i]`` entries along each leading axis.
        # Index with a tuple of slices: that is the documented multi-axis
        # index form, whereas indexing with a *list* of slices relies on a
        # legacy sequence-as-tuple heuristic (already removed from NumPy).
        crop = tuple(slice(0, size) for size in out_shape)
        return cross_correlation[crop]
    if mode == "same":
        # No cropping is performed for 'same'; return the input as-is.
        return cross_correlation
    if mode == "full":
        raise NotImplementedError("Full mode not supported")
    raise ValueError(f"Invalid mode: {mode}")