
cross_correlation

File containing Fourier-slice-based cross-correlation functions for 2DTM.

do_batched_orientation_cross_correlate(image_dft, template_dft, rotation_matrices, projective_filters)

Batched projection and cross-correlation with fixed (batched) filters.

NOTE: This function is similar to do_streamed_orientation_cross_correlate but it computes cross-correlations in batches over the orientation space. For example, with 32 orientations and 10 defocus values to process, the function computes 10 batches of 32 cross-correlations each.

NOTE: This function returns a cross-correlogram in "same" mode (i.e. the same size as the input image). See the numpy.correlate docs for more information.
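
For intuition, here is a minimal, self-contained sketch (not part of the library) of the FFT-based trick these functions rely on: element-wise conjugate multiplication in Fourier space followed by an inverse transform yields a circular cross-correlogram the same size as the image.

```python
import torch

H, W = 64, 64
image = torch.randn(H, W)
kernel = torch.randn(H, W)  # template already padded to the image size

image_dft = torch.fft.rfftn(image, dim=(-2, -1))
kernel_dft = torch.fft.rfftn(kernel, dim=(-2, -1))

# Conjugate multiplication in Fourier space == circular cross-correlation
cross_corr = torch.fft.irfftn(image_dft * kernel_dft.conj(), dim=(-2, -1))
assert cross_corr.shape == (H, W)  # "same" size as the input image
```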

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_dft | Tensor | Real Fourier transform (RFFT) of the image with large image filters already applied. Has shape (H, W // 2 + 1). | required |
| template_dft | Tensor | Real Fourier transform (RFFT) of the template volume to take Fourier slices from. Has shape (l, h, w // 2 + 1) where (l, h, w) is the original real-space shape of the template volume. | required |
| rotation_matrices | Tensor | Rotation matrices to apply to the template volume. Has shape (num_orientations, 3, 3). | required |
| projective_filters | Tensor | Product of 'ctf_filters' and 'whitening_filter_template'. Has shape (num_Cs, num_defocus, h, w // 2 + 1). Is an RFFT and not fftshifted. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Cross-correlation of the image with the template volume for each orientation and defocus value. Will have shape (num_Cs, num_defocus, num_orientations, H, W). |
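
A hedged usage sketch (sizes are illustrative and the inputs are random placeholders, so the output values are meaningless without the real preprocessing; the import path is assumed from this page):

```python
import torch
from leopard_em.backend.cross_correlation import (
    do_batched_orientation_cross_correlate,
)

H, W, s = 512, 512, 128  # illustrative image and cubic-template sizes
num_orientations, num_defocus, num_Cs = 4, 2, 1

# Placeholder inputs; the real pipeline prepares these DFTs and filters.
image_dft = torch.fft.rfftn(torch.randn(H, W))
template_dft = torch.fft.rfftn(torch.randn(s, s, s))
rotation_matrices = torch.eye(3).repeat(num_orientations, 1, 1)
projective_filters = torch.ones(num_Cs, num_defocus, s, s // 2 + 1)

cc = do_batched_orientation_cross_correlate(
    image_dft, template_dft, rotation_matrices, projective_filters
)
assert cc.shape == (num_Cs, num_defocus, num_orientations, H, W)
```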

Source code in src/leopard_em/backend/cross_correlation.py
def do_batched_orientation_cross_correlate(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    rotation_matrices: torch.Tensor,
    projective_filters: torch.Tensor,
) -> torch.Tensor:
    """Batched projection and cross-correlation with fixed (batched) filters.

    NOTE: This function is similar to `do_streamed_orientation_cross_correlate` but
    it computes cross-correlations in batches over the orientation space. For
    example, with 32 orientations and 10 defocus values to process, the function
    computes 10 batches of 32 cross-correlations each.

    NOTE: This function returns a cross-correlogram in "same" mode (i.e. the
    same size as the input image). See the numpy.correlate docs for more information.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real Fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real Fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1) where (l, h, w) is the original
        real-space shape of the template volume.
    rotation_matrices : torch.Tensor
        Rotation matrices to apply to the template volume. Has shape
        (num_orientations, 3, 3).
    projective_filters : torch.Tensor
        Product of 'ctf_filters' and 'whitening_filter_template'. Has shape
        (num_Cs, num_defocus, h, w // 2 + 1). Is an RFFT and not fftshifted.

    Returns
    -------
    torch.Tensor
        Cross-correlation of the image with the template volume for each
        orientation and defocus value. Will have shape
        (num_Cs, num_defocus, num_orientations, H, W).
    """
    # Accounting for RFFT shape
    projection_shape_real = (template_dft.shape[1], template_dft.shape[2] * 2 - 2)
    image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2)

    num_Cs = projective_filters.shape[0]  # pylint: disable=invalid-name
    num_defocus = projective_filters.shape[1]

    cross_correlation = torch.empty(
        size=(
            num_Cs,
            num_defocus,
            rotation_matrices.shape[0],
            *image_shape_real,
        ),
        dtype=image_dft.real.dtype,  # Deduce the real dtype from complex DFT
        device=image_dft.device,
    )

    # Extract central slice(s) from the template volume
    fourier_slice = extract_central_slices_rfft_3d(
        volume_rfft=template_dft,
        image_shape=(projection_shape_real[0],) * 3,  # NOTE: requires cubic template
        rotation_matrices=rotation_matrices,
    )
    fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
    fourier_slice[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
    fourier_slice *= -1  # flip contrast

    # Apply the projective filters on a new batch dimension
    fourier_slice = fourier_slice[None, None, ...] * projective_filters[:, :, None, ...]

    # Inverse Fourier transform into real space and normalize
    projections = torch.fft.irfftn(fourier_slice, dim=(-2, -1))
    projections = torch.fft.ifftshift(projections, dim=(-2, -1))
    projections = normalize_template_projection_compiled(
        projections,
        projection_shape_real,
        image_shape_real,
    )

    for j in range(num_defocus):
        for k in range(num_Cs):
            projections_dft = torch.fft.rfftn(
                projections[k, j, ...], dim=(-2, -1), s=image_shape_real
            )
            projections_dft[..., 0, 0] = 0 + 0j

            # Cross correlation step by element-wise multiplication
            projections_dft = image_dft[None, ...] * projections_dft.conj()
            torch.fft.irfftn(
                projections_dft, dim=(-2, -1), out=cross_correlation[k, j, ...]
            )

    return cross_correlation
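
In 2DTM pipelines the returned correlogram is typically reduced further. As a hedged sketch (not an API shown on this page), one can take a per-pixel maximum-intensity projection (MIP) over the three search dimensions:

```python
import torch

# cc: (num_Cs, num_defocus, num_orientations, H, W) correlogram, e.g. the
# return value above. Flattening the three search dims and taking the max
# gives the best score (and its flat search index) at every pixel.
cc = torch.randn(1, 2, 4, 512, 512)  # placeholder correlogram
mip, best_idx = torch.flatten(cc, start_dim=0, end_dim=2).max(dim=0)
assert mip.shape == (512, 512)
```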

do_batched_orientation_cross_correlate_cpu(image_dft, template_dft, rotation_matrices, projective_filters)

Same as do_batched_orientation_cross_correlate but on the CPU.

The main difference is that this function does not call into a compiled torch function for normalization.

TODO: Figure out a better way to split up CPU/GPU functions while remaining performant and not duplicating code.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_dft | Tensor | Real Fourier transform (RFFT) of the image with large image filters already applied. Has shape (H, W // 2 + 1). | required |
| template_dft | Tensor | Real Fourier transform (RFFT) of the template volume to take Fourier slices from. Has shape (l, h, w // 2 + 1). | required |
| rotation_matrices | Tensor | Rotation matrices to apply to the template volume. Has shape (num_orientations, 3, 3). | required |
| projective_filters | Tensor | Product of 'ctf_filters' and 'whitening_filter_template'. Has shape (num_Cs, num_defocus, h, w // 2 + 1). Is an RFFT and not fftshifted. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Cross-correlation for the batch of orientations and defocus values. |

Source code in src/leopard_em/backend/cross_correlation.py
def do_batched_orientation_cross_correlate_cpu(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    rotation_matrices: torch.Tensor,
    projective_filters: torch.Tensor,
) -> torch.Tensor:
    """Same as `do_streamed_orientation_cross_correlate` but on the CPU.

    The main difference is that this function does not call into a compiled torch
    function for normalization.

    TODO: Figure out a better way to split up CPU/GPU functions while remaining
    performant and not duplicating code.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real Fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real Fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1).
    rotation_matrices : torch.Tensor
        Rotation matrices to apply to the template volume. Has shape
        (num_orientations, 3, 3).
    projective_filters : torch.Tensor
        Product of 'ctf_filters' and 'whitening_filter_template'. Has shape
        (num_Cs, num_defocus, h, w // 2 + 1). Is an RFFT and not fftshifted.

    Returns
    -------
    torch.Tensor
        Cross-correlation for the batch of orientations and defocus values.
    """
    # Accounting for RFFT shape
    projection_shape_real = (template_dft.shape[1], template_dft.shape[2] * 2 - 2)
    image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2)

    # Extract central slice(s) from the template volume
    fourier_slice = extract_central_slices_rfft_3d(
        volume_rfft=template_dft,
        image_shape=(projection_shape_real[0],) * 3,  # NOTE: requires cubic template
        rotation_matrices=rotation_matrices,
    )
    fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
    fourier_slice[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
    fourier_slice *= -1  # flip contrast

    # Apply the projective filters on a new batch dimension
    fourier_slice = fourier_slice[None, None, ...] * projective_filters[:, :, None, ...]

    # Inverse Fourier transform into real space and normalize
    projections = torch.fft.irfftn(fourier_slice, dim=(-2, -1))
    projections = torch.fft.ifftshift(projections, dim=(-2, -1))
    projections = normalize_template_projection(
        projections,
        projection_shape_real,
        image_shape_real,
    )

    # Padded forward Fourier transform for cross-correlation
    projections_dft = torch.fft.rfftn(projections, dim=(-2, -1), s=image_shape_real)
    projections_dft[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)

    # Cross correlation step by element-wise multiplication
    projections_dft = image_dft[None, None, None, ...] * projections_dft.conj()
    cross_correlation = torch.fft.irfftn(projections_dft, dim=(-2, -1))

    return cross_correlation
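
Note the `* 2 - 2` bookkeeping at the top of each function: the RFFT of a length-n signal has n // 2 + 1 frequency bins, so the reverse mapping n = 2 * (n_rfft - 1) used here implicitly assumes even-sized inputs. A quick check of that relation:

```python
import torch

for n in (8, 512):
    n_rfft = torch.fft.rfft(torch.randn(n)).shape[-1]  # n // 2 + 1 bins
    assert n == 2 * (n_rfft - 1)  # holds for even n
# For odd n the inverse relation would be n == 2 * n_rfft - 1 instead.
```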

do_streamed_orientation_cross_correlate(image_dft, template_dft, rotation_matrices, projective_filters, streams)

Calculates a grid of 2D cross-correlations over multiple CUDA streams.

NOTE: This function is more performant than a batched 2D cross-correlation with shape (N, H, W) when the kernel (template) is much smaller than the image (e.g. kernel is 512x512 and image is 4096x4096). Each cross-correlation is computed individually and stored in a batched tensor for the grid of orientations, defoci, and pixel size values.

NOTE: This function returns a cross-correlogram in "same" mode (i.e. the same size as the input image). See the numpy.correlate docs for more information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_dft | Tensor | Real Fourier transform (RFFT) of the image with large image filters already applied. Has shape (H, W // 2 + 1). | required |
| template_dft | Tensor | Real Fourier transform (RFFT) of the template volume to take Fourier slices from. Has shape (l, h, w // 2 + 1) where (l, h, w) is the original real-space shape of the template volume. | required |
| rotation_matrices | Tensor | Rotation matrices to apply to the template volume. Has shape (num_orientations, 3, 3). | required |
| projective_filters | Tensor | Product of 'ctf_filters' and 'whitening_filter_template'. Has shape (num_Cs, num_defocus, h, w // 2 + 1). Is an RFFT and not fftshifted. | required |
| streams | list[Stream] | List of CUDA streams to use for parallel computation. Each stream handles a separate cross-correlation. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Cross-correlation of the image with the template volume for each orientation and defocus value. Will have shape (num_Cs, num_defocus, num_orientations, H, W). |
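
A hedged usage sketch (requires a CUDA device; sizes, stream count, and placeholder inputs are illustrative, and the import path is assumed from this page):

```python
import torch
from leopard_em.backend.cross_correlation import (
    do_streamed_orientation_cross_correlate,
)

device = torch.device("cuda")
streams = [torch.cuda.Stream(device=device) for _ in range(4)]

H, W, s = 512, 512, 128  # illustrative image and cubic-template sizes
image_dft = torch.fft.rfftn(torch.randn(H, W, device=device))
template_dft = torch.fft.rfftn(torch.randn(s, s, s, device=device))
rotation_matrices = torch.eye(3, device=device).repeat(8, 1, 1)
projective_filters = torch.ones(1, 2, s, s // 2 + 1, device=device)

cc = do_streamed_orientation_cross_correlate(
    image_dft, template_dft, rotation_matrices, projective_filters, streams
)
assert cc.shape == (1, 2, 8, H, W)  # (num_Cs, num_defocus, num_orientations, H, W)
```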

Source code in src/leopard_em/backend/cross_correlation.py
def do_streamed_orientation_cross_correlate(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    rotation_matrices: torch.Tensor,
    projective_filters: torch.Tensor,
    streams: list[torch.cuda.Stream],
) -> torch.Tensor:
    """Calculates a grid of 2D cross-correlations over multiple CUDA streams.

    NOTE: This function is more performant than a batched 2D cross-correlation with
    shape (N, H, W) when the kernel (template) is much smaller than the image (e.g.
    kernel is 512x512 and image is 4096x4096). Each cross-correlation is computed
    individually and stored in a batched tensor for the grid of orientations, defoci,
    and pixel size values.

    NOTE: This function returns a cross-correlogram in "same" mode (i.e. the
    same size as the input image). See the numpy.correlate docs for more information.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real Fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real Fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1) where (l, h, w) is the original
        real-space shape of the template volume.
    rotation_matrices : torch.Tensor
        Rotation matrices to apply to the template volume. Has shape
        (num_orientations, 3, 3).
    projective_filters : torch.Tensor
        Product of 'ctf_filters' and 'whitening_filter_template'. Has shape
        (num_Cs, num_defocus, h, w // 2 + 1). Is an RFFT and not fftshifted.
    streams : list[torch.cuda.Stream]
        List of CUDA streams to use for parallel computation. Each stream will
        handle a separate cross-correlation.

    Returns
    -------
    torch.Tensor
        Cross-correlation of the image with the template volume for each
        orientation and defocus value. Will have shape
        (num_Cs, num_defocus, num_orientations, H, W).
    """
    # Accounting for RFFT shape
    projection_shape_real = (template_dft.shape[1], template_dft.shape[2] * 2 - 2)
    image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2)

    num_orientations = rotation_matrices.shape[0]
    num_Cs = projective_filters.shape[0]  # pylint: disable=invalid-name
    num_defocus = projective_filters.shape[1]

    cross_correlation = torch.empty(
        size=(num_Cs, num_defocus, num_orientations, *image_shape_real),
        dtype=image_dft.real.dtype,  # Deduce the real dtype from complex DFT
        device=image_dft.device,
    )

    # Do a batched Fourier slice extraction for all the orientations at once.
    fourier_slices = extract_central_slices_rfft_3d(
        volume_rfft=template_dft,
        image_shape=(projection_shape_real[0],) * 3,
        rotation_matrices=rotation_matrices,
    )
    fourier_slices = torch.fft.ifftshift(fourier_slices, dim=(-2,))
    fourier_slices[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
    fourier_slices *= -1  # flip contrast

    # Iterate over the orientations
    for i in range(num_orientations):
        fourier_slice = fourier_slices[i]

        # Iterate over the different pixel sizes (Cs) and defocus values for this
        # particular orientation
        for j in range(num_defocus):
            for k in range(num_Cs):
                # Use a round-robin scheduling for the streams
                job_idx = (i * num_defocus * num_Cs) + (j * num_Cs) + k
                stream_idx = job_idx % len(streams)
                stream = streams[stream_idx]

                with torch.cuda.stream(stream):
                    # Apply the projective filter and do template normalization
                    fourier_slice_filtered = fourier_slice * projective_filters[k, j]
                    projection = torch.fft.irfft2(fourier_slice_filtered)
                    projection = torch.fft.ifftshift(projection, dim=(-2, -1))
                    projection = normalize_template_projection_compiled(
                        projection,
                        projection_shape_real,
                        image_shape_real,
                    )

                    # Padded forward Fourier transform for cross-correlation
                    projection_dft = torch.fft.rfft2(projection, s=image_shape_real)
                    projection_dft[0, 0] = 0 + 0j

                    # Cross correlation step by element-wise multiplication
                    projection_dft = image_dft * projection_dft.conj()
                    torch.fft.irfft2(
                        projection_dft,
                        s=image_shape_real,
                        out=cross_correlation[k, j, i],
                    )

    # Wait for all streams to finish
    for stream in streams:
        stream.synchronize()

    # shape is (num_Cs, num_defocus, num_orientations, H, W)
    return cross_correlation
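
For reference, a hedged illustration of the round-robin scheduling used in the loops above: jobs are laid out in (orientation, defocus, Cs) order and distributed cyclically across the streams.

```python
# Same index arithmetic as in the source, with toy sizes.
num_orientations, num_defocus, num_Cs, n_streams = 2, 2, 1, 3
for i in range(num_orientations):
    for j in range(num_defocus):
        for k in range(num_Cs):
            job_idx = (i * num_defocus * num_Cs) + (j * num_Cs) + k
            print(job_idx, "-> stream", job_idx % n_streams)
# Jobs 0, 1, 2, 3 map to streams 0, 1, 2, 0.
```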