core_match_template

Pure PyTorch implementation of the whole-orientation search backend.

construct_multi_gpu_match_template_kwargs(image_dft, template_dft, euler_angles, projective_filters, defocus_values, pixel_values, orientation_batch_size, num_cuda_streams, devices)

Split orientations between requested devices.

See the core_match_template function for further descriptions of the input parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image_dft` | `Tensor` | DFT of the image. | *required* |
| `template_dft` | `Tensor` | DFT of the template. | *required* |
| `euler_angles` | `Tensor` | Euler angles to search. | *required* |
| `projective_filters` | `Tensor` | Filters to apply to each projection. | *required* |
| `defocus_values` | `Tensor` | Corresponding defocus values for each filter. | *required* |
| `pixel_values` | `Tensor` | Corresponding pixel size values for each filter. | *required* |
| `orientation_batch_size` | `int` | Number of projections to calculate at once. | *required* |
| `num_cuda_streams` | `int` | Number of CUDA streams to use for parallelizing cross-correlation computation. | *required* |
| `devices` | `list[device]` | List of devices to split the orientations across. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, Tensor \| device \| int]]` | List of dictionaries containing the kwargs to call the single-GPU function. Each index in the list corresponds to a different device, which is included under the `"device"` key; the single-GPU function later moves the tensors onto that device. |

Source code in src/leopard_em/backend/core_match_template.py
def construct_multi_gpu_match_template_kwargs(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    euler_angles: torch.Tensor,
    projective_filters: torch.Tensor,
    defocus_values: torch.Tensor,
    pixel_values: torch.Tensor,
    orientation_batch_size: int,
    num_cuda_streams: int,
    devices: list[torch.device],
) -> list[dict[str, torch.Tensor | torch.device | int]]:
    """Split orientations between requested devices.

    See the `core_match_template` function for further descriptions of the
    input parameters.

    Parameters
    ----------
    image_dft : torch.Tensor
        DFT of the image.
    template_dft : torch.Tensor
        DFT of the template.
    euler_angles : torch.Tensor
        Euler angles to search.
    projective_filters : torch.Tensor
        Filters to apply to each projection.
    defocus_values : torch.Tensor
        Corresponding defocus values for each filter.
    pixel_values : torch.Tensor
        Corresponding pixel size values for each filter.
    orientation_batch_size : int
        Number of projections to calculate at once.
    num_cuda_streams : int
        Number of CUDA streams to use for parallelizing cross-correlation computation.
    devices : list[torch.device]
        List of devices to split the orientations across.

    Returns
    -------
    list[dict[str, torch.Tensor | torch.device | int]]
        List of dictionaries containing the kwargs to call the single-GPU
        function. Each index in the list corresponds to a different device,
        which is included under the "device" key; the single-GPU function
        later moves the tensors onto that device.
    """
    kwargs_per_device = []

    # Split the euler angles across devices
    euler_angles_split = euler_angles.chunk(len(devices))

    for device, euler_angles_device in zip(devices, euler_angles_split):
        # Allocate and construct the kwargs for this device
        kwargs = {
            "image_dft": image_dft,
            "template_dft": template_dft,
            "euler_angles": euler_angles_device,
            "projective_filters": projective_filters,
            "defocus_values": defocus_values,
            "pixel_values": pixel_values,
            "orientation_batch_size": orientation_batch_size,
            "num_cuda_streams": num_cuda_streams,
            "device": device,
        }

        kwargs_per_device.append(kwargs)

    return kwargs_per_device
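
A minimal sketch of calling this helper with dummy tensors, assuming two CUDA devices are available. All shapes and values below are illustrative, not prescribed by the library:

```python
import torch

from leopard_em.backend.core_match_template import (
    construct_multi_gpu_match_template_kwargs,
)

devices = [torch.device("cuda:0"), torch.device("cuda:1")]

kwargs_list = construct_multi_gpu_match_template_kwargs(
    image_dft=torch.randn(256, 129, dtype=torch.complex64),       # (H, W // 2 + 1)
    template_dft=torch.randn(32, 32, 17, dtype=torch.complex64),  # (l, h, w // 2 + 1)
    euler_angles=torch.zeros(1000, 3),                            # (num_orientations, 3)
    projective_filters=torch.rand(1, 3, 32, 17),  # (num_Cs, num_defocus, h, w // 2 + 1)
    defocus_values=torch.linspace(-600.0, 600.0, 3),
    pixel_values=torch.tensor([1.06]),
    orientation_batch_size=8,
    num_cuda_streams=1,
    devices=devices,
)

# One kwargs dict per device; the 1000 orientations are chunked evenly,
# and each dict carries its target device under the "device" key.
assert len(kwargs_list) == 2
assert kwargs_list[0]["euler_angles"].shape[0] == 500
```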

core_match_template(image_dft, template_dft, ctf_filters, whitening_filter_template, defocus_values, pixel_values, euler_angles, device, orientation_batch_size=1, num_cuda_streams=1)

Core function for performing the whole-orientation search.

With the RFFT, the last dimension (fastest dimension) is half the width of the input, hence the shape of W // 2 + 1 instead of W for some of the input parameters.
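
To make the shape convention concrete, here is a quick check of the RFFT output shape in PyTorch (this snippet is illustrative and not part of the library):

```python
import torch

image = torch.zeros(16, 16)          # (H, W)
image_dft = torch.fft.rfft2(image)   # real-to-complex 2D FFT
print(image_dft.shape)               # torch.Size([16, 9]), i.e. (H, W // 2 + 1)
```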

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image_dft` | `Tensor` | Real Fourier transform (RFFT) of the image with large image filters already applied. Has shape `(H, W // 2 + 1)`. | *required* |
| `template_dft` | `Tensor` | Real Fourier transform (RFFT) of the template volume to take Fourier slices from. Has shape `(l, h, w // 2 + 1)` with the last dimension being the half-dimension for the real-FFT transformation. NOTE: The original template volume should be a cubic volume, i.e. `h == w == l`. | *required* |
| `ctf_filters` | `Tensor` | Stack of CTF filters at different pixel size (Cs) and defocus values to use in the search. Has shape `(num_Cs, num_defocus, h, w // 2 + 1)`, where `num_Cs` is the number of pixel sizes searched over and `num_defocus` is the number of defocus values searched over. | *required* |
| `whitening_filter_template` | `Tensor` | Whitening filter for the template volume. Has shape `(h, w // 2 + 1)`. Gets multiplied with the CTF filters to create a filter stack applied to each orientation projection. | *required* |
| `euler_angles` | `Tensor` | Euler angles (in 'ZYZ' convention) to search over. Has shape `(num_orientations, 3)`. | *required* |
| `defocus_values` | `Tensor` | Defocus values corresponding to the CTF filters, in units of Angstroms. Has shape `(num_defocus,)`. | *required* |
| `pixel_values` | `Tensor` | Pixel size values corresponding to the CTF filters, in units of Angstroms. Has shape `(num_Cs,)`. | *required* |
| `device` | `device \| list[device]` | Device or devices to split computation across. | *required* |
| `orientation_batch_size` | `int` | Number of projections, at different orientations, to calculate simultaneously. Larger values use more memory but help amortize the cost of Fourier slice extraction. The default is 1, but values larger than 1 should generally be used for performance. | `1` |
| `num_cuda_streams` | `int` | Number of CUDA streams to use for parallelizing cross-correlation computation. More streams can improve performance, especially on high-end GPUs, but too many streams degrade it. The default of 1 performs well in most cases, though high-end GPUs can benefit from increasing it. NOTE: If the number of streams exceeds the number of cross-correlations to compute per batch, it is reduced to that number to avoid unnecessary overhead (see the worked example after this table). | `1` |
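
As a worked example of the stream cap: with `orientation_batch_size=10`, three defocus values, and two pixel sizes, each batch computes `10 * 3 * 2 = 60` cross-correlations, so requesting, say, 128 streams emits a warning and reduces `num_cuda_streams` to 60.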

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Tensor]` | Dictionary containing the key, value pairs listed below. |

- "mip": Maximum intensity projection of the cross-correlation values across
  orientation and defocus search space.
- "scaled_mip": Z-score scaled MIP of the cross-correlation values (see the
  sketch after this list).
- "best_phi": Best phi angle for each pixel.
- "best_theta": Best theta angle for each pixel.
- "best_psi": Best psi angle for each pixel.
- "best_defocus": Best defocus value for each pixel.
- "correlation_mean": Mean of the cross-correlation values for each pixel.
- "correlation_variance": Variance of the cross-correlation values for each
  pixel.
- "total_projections": Total number of projections calculated.
- "total_orientations": Total number of orientations searched.
- "total_defocus": Total number of defocus values searched.
Source code in src/leopard_em/backend/core_match_template.py
def core_match_template(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,  # already fftshifted
    ctf_filters: torch.Tensor,
    whitening_filter_template: torch.Tensor,
    defocus_values: torch.Tensor,
    pixel_values: torch.Tensor,
    euler_angles: torch.Tensor,
    device: torch.device | list[torch.device],
    orientation_batch_size: int = 1,
    num_cuda_streams: int = 1,
) -> dict[str, torch.Tensor]:
    """Core function for performing the whole-orientation search.

    With the RFFT, the last dimension (fastest dimension) is half the width
    of the input, hence the shape of W // 2 + 1 instead of W for some of the
    input parameters.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real Fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real Fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1) with the last dimension being the
        half-dimension for the real-FFT transformation. NOTE: The original template
        volume should be a cubic volume, i.e. h == w == l.
    ctf_filters : torch.Tensor
        Stack of CTF filters at different pixel size (Cs) and defocus values to use in
        the search. Has shape (num_Cs, num_defocus, h, w // 2 + 1) where num_Cs is the
        number of pixel sizes searched over and num_defocus is the number of defocus
        values searched over.
    whitening_filter_template : torch.Tensor
        Whitening filter for the template volume. Has shape (h, w // 2 + 1).
        Gets multiplied with the CTF filters to create a filter stack applied to each
        orientation projection.
    euler_angles : torch.Tensor
        Euler angles (in 'ZYZ' convention) to search over. Has shape
        (num_orientations, 3).
    defocus_values : torch.Tensor
        Defocus values corresponding to the CTF filters, in units of Angstroms. Has
        shape (num_defocus,).
    pixel_values : torch.Tensor
        Pixel size values corresponding to the CTF filters, in units of Angstroms.
        Has shape (num_Cs,).
    device : torch.device | list[torch.device]
        Device or devices to split computation across.
    orientation_batch_size : int, optional
        Number of projections, at different orientations, to calculate simultaneously.
        Larger values will use more memory, but can help amortize the cost of Fourier
        slice extraction. The default is 1, but generally values larger than 1 should
        be used for performance.
    num_cuda_streams : int, optional
        Number of CUDA streams to use for parallelizing cross-correlation computation.
        More streams can lead to better performance, especially for high-end GPUs, but
        the performance will degrade if too many streams are used. The default is 1
        which performs well in most cases, but high-end GPUs can benefit from
        increasing this value. NOTE: If the number of streams is greater than the
        number of cross-correlations to compute per batch, then the number of streams
        will be reduced to the number of cross-correlations per batch. This is done to
        avoid unnecessary overhead and performance degradation.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing the following key, value pairs:

            - "mip": Maximum intensity projection of the cross-correlation values across
              orientation and defocus search space.
            - "scaled_mip": Z-score scaled MIP of the cross-correlation values.
            - "best_phi": Best phi angle for each pixel.
            - "best_theta": Best theta angle for each pixel.
            - "best_psi": Best psi angle for each pixel.
            - "best_defocus": Best defocus value for each pixel.
            - "best_pixel_size": Best pixel size value for each pixel.
            - "correlation_sum": Sum of cross-correlation values for each pixel.
            - "correlation_squared_sum": Sum of squared cross-correlation values for
              each pixel.
            - "total_projections": Total number of projections calculated.
            - "total_orientations": Total number of orientations searched.
            - "total_defocus": Total number of defocus values searched.
    """
    ################################################################
    ### Initial checks of input parameters plus some adjustments ###
    ################################################################
    # If there are more streams than cross-correlations to compute per batch, then
    # reduce the number of streams to the number of cross-correlations per batch.
    total_cc_per_batch = (
        orientation_batch_size * defocus_values.shape[0] * pixel_values.shape[0]
    )
    if num_cuda_streams > total_cc_per_batch:
        warnings.warn(
            f"Number of CUDA streams ({num_cuda_streams}) is greater than the "
            f"number of cross-correlations per batch ({total_cc_per_batch}). "
            f"The total cross-correlations per batch is number of pixel sizes "
            f"({pixel_values.shape[0]}) * number of defocus values "
            f"({defocus_values.shape[0]}) * orientation batch size "
            f"({orientation_batch_size}). "
            f"Reducing number of streams to {total_cc_per_batch} for performance.",
            stacklevel=2,
        )
        num_cuda_streams = total_cc_per_batch

    # Ensure the tensors are all on the CPU. The _core_match_template_single_gpu
    # function will move them onto the correct device.
    image_dft = image_dft.cpu()
    template_dft = template_dft.cpu()
    ctf_filters = ctf_filters.cpu()
    whitening_filter_template = whitening_filter_template.cpu()
    defocus_values = defocus_values.cpu()
    pixel_values = pixel_values.cpu()
    euler_angles = euler_angles.cpu()

    ##############################################################
    ### Pre-multiply the whitening filter with the CTF filters ###
    ##############################################################

    projective_filters = ctf_filters * whitening_filter_template[None, None, ...]

    #########################################
    ### Split orientations across devices ###
    #########################################

    if isinstance(device, torch.device):
        device = [device]

    kwargs_per_device = construct_multi_gpu_match_template_kwargs(
        image_dft=image_dft,
        template_dft=template_dft,
        euler_angles=euler_angles,
        projective_filters=projective_filters,
        defocus_values=defocus_values,
        pixel_values=pixel_values,
        orientation_batch_size=orientation_batch_size,
        num_cuda_streams=num_cuda_streams,
        devices=device,
    )

    result_dict = run_multiprocess_jobs(
        target=_core_match_template_single_gpu,
        kwargs_list=kwargs_per_device,
    )

    # Get the aggregated results
    partial_results = [result_dict[i] for i in range(len(kwargs_per_device))]
    aggregated_results = aggregate_distributed_results(partial_results)
    mip = aggregated_results["mip"]
    best_phi = aggregated_results["best_phi"]
    best_theta = aggregated_results["best_theta"]
    best_psi = aggregated_results["best_psi"]
    best_defocus = aggregated_results["best_defocus"]
    correlation_sum = aggregated_results["correlation_sum"]
    correlation_squared_sum = aggregated_results["correlation_squared_sum"]
    total_projections = aggregated_results["total_projections"]

    mip_scaled = torch.empty_like(mip)
    mip, mip_scaled, correlation_mean, correlation_variance = scale_mip(
        mip=mip,
        mip_scaled=mip_scaled,
        correlation_sum=correlation_sum,
        correlation_squared_sum=correlation_squared_sum,
        total_correlation_positions=total_projections,
    )

    return {
        "mip": mip,
        "scaled_mip": mip_scaled,
        "best_phi": best_phi,
        "best_theta": best_theta,
        "best_psi": best_psi,
        "best_defocus": best_defocus,
        "correlation_mean": correlation_mean,
        "correlation_variance": correlation_variance,
        "total_projections": total_projections,
        "total_orientations": euler_angles.shape[0],
        "total_defocus": defocus_values.shape[0],
    }
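
Putting it together, a hedged end-to-end sketch of invoking the search. Random tensors of the documented shapes stand in for properly preprocessed data, so the numerical output is meaningless; the point is the call signature, the shape conventions, and that CUDA hardware is required at runtime:

```python
import torch

from leopard_em.backend.core_match_template import core_match_template

H, W = 256, 256             # image size
l = h = w = 32              # cubic template volume (h == w == l)
num_Cs, num_defocus = 1, 3  # pixel sizes and defocus values searched

results = core_match_template(
    image_dft=torch.randn(H, W // 2 + 1, dtype=torch.complex64),
    template_dft=torch.randn(l, h, w // 2 + 1, dtype=torch.complex64),
    ctf_filters=torch.rand(num_Cs, num_defocus, h, w // 2 + 1),
    whitening_filter_template=torch.rand(h, w // 2 + 1),
    defocus_values=torch.linspace(-600.0, 600.0, num_defocus),
    pixel_values=torch.tensor([1.06]),
    euler_angles=torch.zeros(100, 3),   # (num_orientations, 3), 'ZYZ' convention
    device=[torch.device("cuda:0")],    # a single device also works un-listed
    orientation_batch_size=8,
    num_cuda_streams=1,
)

# Per-pixel maximum cross-correlation over the whole search space.
print(results["mip"].shape)
```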