
core_match_template

Pure PyTorch implementation of the whole-orientation search backend.

core_match_template(image_dft, template_dft, ctf_filters, whitening_filter_template, defocus_values, pixel_values, euler_angles, device, orientation_batch_size=1, num_cuda_streams=1, backend='streamed')

Core function for performing the whole-orientation search.

With the RFFT, the last (fastest-varying) dimension holds only the non-redundant half of the spectrum, so several of the input parameters have a last dimension of W // 2 + 1 instead of W.
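
As a quick illustration of that shape rule, the helper below computes the expected RFFT shape for a given real-space shape. The name `rfft_shape` is hypothetical and not part of Leopard-EM; it only mirrors the `W // 2 + 1` arithmetic described above, and the image and template sizes are made up.

```python
def rfft_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Return the shape of the real FFT of an array with the given shape.

    Only the last (fastest-varying) dimension is reduced, to W // 2 + 1.
    """
    *leading, width = shape
    return (*leading, width // 2 + 1)


# Illustrative sizes: a 4096 x 4096 image and a cubic 512^3 template volume
image_dft_shape = rfft_shape((4096, 4096))
template_dft_shape = rfft_shape((512, 512, 512))
```

A check like this can be useful before calling `core_match_template`, since a tensor whose last dimension is W rather than W // 2 + 1 usually means a full FFT was passed where an RFFT was expected.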

Parameters:

Name Type Description Default
image_dft Tensor

Real Fourier transform (RFFT) of the image with large image filters already applied. Has shape (H, W // 2 + 1).

required
template_dft Tensor

Real Fourier transform (RFFT) of the template volume to take Fourier slices from. Has shape (l, h, w // 2 + 1), where the last dimension is the half-dimension of the real FFT. NOTE: The original template volume should be cubic, i.e. h == w == l.

required
ctf_filters Tensor

Stack of CTF filters at the different pixel sizes (Cs) and defocus values to use in the search. Has shape (num_Cs, num_defocus, h, w // 2 + 1), where num_Cs is the number of pixel sizes searched over and num_defocus is the number of defocus values searched over.

required
whitening_filter_template Tensor

Whitening filter for the template volume. Has shape (h, w // 2 + 1). Gets multiplied with the ctf filters to create a filter stack applied to each orientation projection.

required
euler_angles Tensor

Euler angles (in 'ZYZ' convention & in units of degrees) to search over. Has shape (num_orientations, 3).

required
defocus_values Tensor

Defocus values corresponding to the CTF filters, in units of Angstroms. Has shape (num_defocus,).

required
pixel_values Tensor

Pixel size values corresponding to the CTF filters, in units of Angstroms. Has shape (num_Cs,).

required
device device | list[device]

Device or devices to split computation across.

required
orientation_batch_size int

Number of projections, at different orientations, to calculate simultaneously. Larger values will use more memory, but can help amortize the cost of Fourier slice extraction. The default is 1, but generally values larger than 1 should be used for performance.

1
num_cuda_streams int

Number of CUDA streams to use for parallelizing cross-correlation computation. More streams can lead to better performance, especially for high-end GPUs, but the performance will degrade if too many streams are used. The default is 1 which performs well in most cases, but high-end GPUs can benefit from increasing this value. NOTE: If the number of streams is greater than the number of cross-correlations to compute per batch, then the number of streams will be reduced to the number of cross-correlations per batch. This is done to avoid unnecessary overhead and performance degradation.

1
backend str

The backend to use for computation. Defaults to 'streamed'. Must be 'streamed' or 'batched'.

'streamed'

Returns:

Type Description
dict[str, Tensor]

Dictionary containing the following key, value pairs:

- "mip": Maximum intensity projection of the cross-correlation values across
  orientation and defocus search space.
- "scaled_mip": Z-score scaled MIP of the cross-correlation values.
- "best_phi": Best phi angle for each pixel.
- "best_theta": Best theta angle for each pixel.
- "best_psi": Best psi angle for each pixel.
- "best_defocus": Best defocus value for each pixel.
- "correlation_mean": Mean of the cross-correlation values for each pixel.
- "correlation_variance": Variance of the cross-correlation values for each
  pixel.
- "total_projections": Total number of projections searched (orientations *
  defocus values * pixel sizes).
- "total_orientations": Total number of orientations searched.
- "total_defocus": Total number of defocus values searched.
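
A minimal sketch of how a z-score scaled MIP could be derived for a single pixel from running sums, assuming the standard identity Var[x] = E[x^2] - E[x]^2. The function `zscore_mip` is hypothetical and works on plain Python floats; the library's `scale_mip` operates on tensors and may differ in its details.

```python
import math


def zscore_mip(mip: float, corr_sum: float, corr_sq_sum: float, n: int) -> float:
    """Z-score of the maximum at one pixel, given running sums over n correlations."""
    mean = corr_sum / n
    # Var[x] = E[x^2] - E[x]^2; clamp at zero to guard against round-off
    variance = max(corr_sq_sum / n - mean**2, 0.0)
    std = math.sqrt(variance)
    if std == 0.0:
        return 0.0  # assumption: a constant pixel carries no signal
    return (mip - mean) / std


# Four cross-correlation values observed at one pixel (made-up numbers)
values = [1.0, 2.0, 3.0, 6.0]
z = zscore_mip(max(values), sum(values), sum(v * v for v in values), len(values))
```

Keeping only the sum and the sum of squares lets the statistics accumulate across devices without storing every cross-correlation value.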
Source code in src/leopard_em/backend/core_match_template.py
def core_match_template(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,  # already fftshifted
    ctf_filters: torch.Tensor,
    whitening_filter_template: torch.Tensor,
    defocus_values: torch.Tensor,
    pixel_values: torch.Tensor,
    euler_angles: torch.Tensor,
    device: torch.device | list[torch.device],
    orientation_batch_size: int = 1,
    num_cuda_streams: int = 1,
    backend: str = "streamed",
) -> dict[str, torch.Tensor]:
    """Core function for performing the whole-orientation search.

    With the RFFT, the last dimension (fastest dimension) is half the width
    of the input, hence the shape of W // 2 + 1 instead of W for some of the
    input parameters.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real Fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real Fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1) with the last dimension being the
        half-dimension for real-FFT transformation. NOTE: The original template volume
        should be a cubic volume, i.e. h == w == l.
    ctf_filters : torch.Tensor
        Stack of CTF filters at different pixel sizes (Cs) and defocus values to use
        in the search. Has shape (num_Cs, num_defocus, h, w // 2 + 1) where num_Cs is
        the number of pixel sizes searched over, and num_defocus is the number of
        defocus values searched over.
    whitening_filter_template : torch.Tensor
        Whitening filter for the template volume. Has shape (h, w // 2 + 1).
        Gets multiplied with the ctf filters to create a filter stack applied to each
        orientation projection.
    euler_angles : torch.Tensor
        Euler angles (in 'ZYZ' convention & in units of degrees) to search over. Has
        shape (num_orientations, 3).
    defocus_values : torch.Tensor
        Defocus values corresponding to the CTF filters, in units of Angstroms. Has
        shape (num_defocus,).
    pixel_values : torch.Tensor
        Pixel size values corresponding to the CTF filters, in units of Angstroms.
        Has shape (num_Cs,).
    device : torch.device | list[torch.device]
        Device or devices to split computation across.
    orientation_batch_size : int, optional
        Number of projections, at different orientations, to calculate simultaneously.
        Larger values will use more memory, but can help amortize the cost of Fourier
        slice extraction. The default is 1, but generally values larger than 1 should
        be used for performance.
    num_cuda_streams : int, optional
        Number of CUDA streams to use for parallelizing cross-correlation computation.
        More streams can lead to better performance, especially for high-end GPUs, but
        the performance will degrade if too many streams are used. The default is 1
        which performs well in most cases, but high-end GPUs can benefit from
        increasing this value. NOTE: If the number of streams is greater than the
        number of cross-correlations to compute per batch, then the number of streams
        will be reduced to the number of cross-correlations per batch. This is done to
        avoid unnecessary overhead and performance degradation.
    backend : str, optional
        The backend to use for computation. Defaults to 'streamed'.
        Must be 'streamed' or 'batched'.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing the following key, value pairs:

            - "mip": Maximum intensity projection of the cross-correlation values across
              orientation and defocus search space.
            - "scaled_mip": Z-score scaled MIP of the cross-correlation values.
            - "best_phi": Best phi angle for each pixel.
            - "best_theta": Best theta angle for each pixel.
            - "best_psi": Best psi angle for each pixel.
            - "best_defocus": Best defocus value for each pixel.
            - "correlation_mean": Mean of the cross-correlation values for each pixel.
            - "correlation_variance": Variance of the cross-correlation values for
              each pixel.
            - "total_projections": Total number of projections searched
              (orientations * defocus values * pixel sizes).
            - "total_orientations": Total number of orientations searched.
            - "total_defocus": Total number of defocus values searched.
    """
    ############################################################
    ### Initial checks and adjustments for input parameters ###
    ############################################################
    # If there are more streams than cross-correlations to compute per batch, then
    # reduce the number of streams to the number of cross-correlations per batch.
    total_cc_per_batch = (
        orientation_batch_size * defocus_values.shape[0] * pixel_values.shape[0]
    )
    if num_cuda_streams > total_cc_per_batch:
        warnings.warn(
            f"Number of CUDA streams ({num_cuda_streams}) is greater than the "
            f"number of cross-correlations per batch ({total_cc_per_batch}). "
            f"The total cross-correlations per batch is number of pixel sizes "
            f"({pixel_values.shape[0]}) * number of defocus values "
            f"({defocus_values.shape[0]}) * orientation batch size "
            f"({orientation_batch_size}). "
            f"Reducing number of streams to {total_cc_per_batch} for performance.",
            stacklevel=2,
        )
        num_cuda_streams = total_cc_per_batch

    # Ensure the tensors are all on the CPU. The _core_match_template_single_gpu
    # function will move them onto the correct device.
    image_dft = image_dft.cpu()
    template_dft = template_dft.cpu()
    ctf_filters = ctf_filters.cpu()
    whitening_filter_template = whitening_filter_template.cpu()
    defocus_values = defocus_values.cpu()
    pixel_values = pixel_values.cpu()
    euler_angles = euler_angles.cpu()

    ##############################################################
    ### Pre-multiply the whitening filter with the CTF filters ###
    ##############################################################

    projective_filters = ctf_filters * whitening_filter_template[None, None, ...]
    total_projections = (
        euler_angles.shape[0] * defocus_values.shape[0] * pixel_values.shape[0]
    )

    ############################################################
    ### Shared queue mechanism and multiprocessing arguments ###
    ############################################################

    if isinstance(device, torch.device):
        device = [device]

    index_queue = MultiprocessWorkIndexQueue(
        total_indices=euler_angles.shape[0],
        batch_size=orientation_batch_size,
        prefetch_size=10,
        num_processes=len(device),
    )
    global_pbar, device_pbars = setup_progress_tracking(
        index_queue=index_queue,
        unit_scale=defocus_values.shape[0] * pixel_values.shape[0],
        devices=device,
    )
    progress_callback = partial(
        monitor_match_template_progress,
        queue=index_queue,
        pbar=global_pbar,
        device_pbars=device_pbars,
    )

    kwargs_per_device = []
    for d in device:
        kwargs = {
            "index_queue": index_queue,
            "image_dft": image_dft,
            "template_dft": template_dft,
            "euler_angles": euler_angles,
            "projective_filters": projective_filters,
            "defocus_values": defocus_values,
            "pixel_values": pixel_values,
            "orientation_batch_size": orientation_batch_size,
            "num_cuda_streams": num_cuda_streams,
            "backend": backend,
            "device": d,
        }

        kwargs_per_device.append(kwargs)

    result_dict = run_multiprocess_jobs(
        target=_core_match_template_multiprocess_wrapper,
        kwargs_list=kwargs_per_device,
        post_start_callback=progress_callback,
    )

    # Get the aggregated results
    partial_results = [result_dict[i] for i in range(len(kwargs_per_device))]
    aggregated_results = aggregate_distributed_results(partial_results)
    mip = aggregated_results["mip"]
    best_global_index = aggregated_results["best_global_index"]
    correlation_sum = aggregated_results["correlation_sum"]
    correlation_squared_sum = aggregated_results["correlation_squared_sum"]

    # Map from global search index to the best defocus & angles
    best_phi, best_theta, best_psi, best_defocus = decode_global_search_index(
        best_global_index, pixel_values, defocus_values, euler_angles
    )

    mip_scaled = torch.empty_like(mip)
    mip, mip_scaled, correlation_mean, correlation_variance = scale_mip(
        mip=mip,
        mip_scaled=mip_scaled,
        correlation_sum=correlation_sum,
        correlation_squared_sum=correlation_squared_sum,
        total_correlation_positions=total_projections,
    )

    return {
        "mip": mip,
        "scaled_mip": mip_scaled,
        "best_phi": best_phi,
        "best_theta": best_theta,
        "best_psi": best_psi,
        "best_defocus": best_defocus,
        "correlation_mean": correlation_mean,
        "correlation_variance": correlation_variance,
        "total_projections": total_projections,
        "total_orientations": euler_angles.shape[0],
        "total_defocus": defocus_values.shape[0],
    }
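
The mapping from a flat search index back to the individual search axes can be sketched with `divmod`. This is only an illustration: the function `decode_flat_index` is hypothetical, and the row-major (orientation, pixel size, defocus) ordering with defocus varying fastest is an assumption, not something the source above confirms about `decode_global_search_index`.

```python
def decode_flat_index(
    index: int, num_pixel_sizes: int, num_defocus: int
) -> tuple[int, int, int]:
    """Split a flat search index into (orientation, pixel size, defocus) indices.

    Assumes row-major ordering with defocus as the fastest-varying axis.
    """
    rest, defocus_idx = divmod(index, num_defocus)
    orientation_idx, pixel_idx = divmod(rest, num_pixel_sizes)
    return orientation_idx, pixel_idx, defocus_idx


# Illustrative search space: 3 pixel sizes x 5 defocus values per orientation
euler_angles = [(0.0, 0.0, 0.0), (0.0, 2.5, 0.0), (0.0, 5.0, 0.0)]
defocus_values = [-200.0, -100.0, 0.0, 100.0, 200.0]

orientation_idx, pixel_idx, defocus_idx = decode_flat_index(22, 3, 5)
best_phi, best_theta, best_psi = euler_angles[orientation_idx]
best_defocus = defocus_values[defocus_idx]
```

Storing one integer per pixel and decoding it once at the end is cheaper than tracking separate best-angle and best-defocus maps during the search.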

monitor_match_template_progress(queue, pbar, device_pbars, poll_interval=1.0)

Helper function for periodic polling of shared queue by tqdm.

This function monitors the progress of template matching and updates progress bars.
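
The polling pattern can be sketched without tqdm or multiprocessing. `FakeQueue` and `poll_progress` below are hypothetical stand-ins written for this example; they mimic only the `get_current_index` and `total_indices` parts of the shared queue and record the deltas that a real monitor would pass to `pbar.update`.

```python
import time


class FakeQueue:
    """Hypothetical stand-in for the shared work queue; advances on each poll."""

    def __init__(self, total_indices: int, step: int) -> None:
        self.total_indices = total_indices
        self._step = step
        self._current = 0

    def get_current_index(self) -> int:
        # Simulate workers claiming `step` more indices since the last poll
        self._current = min(self._current + self._step, self.total_indices)
        return self._current


def poll_progress(queue: FakeQueue, poll_interval: float = 0.0) -> list[int]:
    """Poll the queue until it is exhausted and return the per-poll deltas."""
    deltas = []
    last_progress = 0
    while last_progress < queue.total_indices:
        progress = queue.get_current_index()
        delta = progress - last_progress
        if delta > 0:
            deltas.append(delta)  # a real monitor would call pbar.update(delta)
            last_progress = progress
        time.sleep(poll_interval)
    return deltas


deltas = poll_progress(FakeQueue(total_indices=10, step=4))
```

Updating by the delta rather than the absolute index keeps the progress bar correct even when several indices are claimed between two polls.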

Source code in src/leopard_em/backend/core_match_template.py
def monitor_match_template_progress(
    queue: "MultiprocessWorkIndexQueue",
    pbar: tqdm.tqdm,
    device_pbars: dict[int, tqdm.tqdm],
    poll_interval: float = 1.0,  # in seconds
) -> None:
    """Helper function for periodic polling of shared queue by tqdm.

    This function monitors the progress of template matching and updates progress bars.
    """
    last_progress = 0
    last_per_device = [0] * len(device_pbars)

    try:
        while True:
            if queue.error_occurred():
                raise RuntimeError("Exiting due to error in another process.")
            progress = queue.get_current_index()
            delta = progress - last_progress

            # Update the global search progress bar
            if delta > 0:
                pbar.update(delta)
                last_progress = progress

            # Update each of the progress bars for each device
            device_counts = queue.get_process_counts()
            for i, dv_pbar in enumerate(device_pbars.values()):
                delta = device_counts[i] - last_per_device[i]
                if delta > 0:
                    dv_pbar.update(delta)
                    last_per_device[i] = device_counts[i]

            # Done with tracking when progress reaches the end of the queue
            if last_progress >= queue.total_indices:
                break

            time.sleep(poll_interval)
    except Exception as e:
        print(f"Error occurred: {e}")
        queue.set_error_flag()
        raise  # re-raise the active exception with its original traceback
    finally:
        # Clean up progress bars
        for dv_pbar in device_pbars.values():
            dv_pbar.close()
        pbar.close()

setup_progress_tracking(index_queue, unit_scale, devices)

Setup global and per-device tqdm progress bars for template matching.

Parameters:

Name Type Description Default
index_queue MultiprocessWorkIndexQueue

The shared work queue tracking global indices.

required
unit_scale Union[float, int]

Scaling factor applied to the progress bar counts (passed to tqdm as unit_scale).

required
devices list[device]

List of devices to create per-device progress bars for.

required

Returns:

Type Description
tuple[tqdm, dict[int, tqdm]]

Global progress bar and dictionary of per-device progress bars.
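
In `core_match_template`, `unit_scale` is set to the number of defocus values times the number of pixel sizes, so each orientation taken from the queue is displayed as that many cross-correlations. The arithmetic can be sketched as follows; `correlations_done` is a hypothetical helper and the counts are made up.

```python
def correlations_done(
    orientations_done: int, num_defocus: int, num_pixel_sizes: int
) -> int:
    """Convert a count of processed orientations into cross-correlations."""
    unit_scale = num_defocus * num_pixel_sizes
    return orientations_done * unit_scale


# 1,000 orientations searched with 13 defocus values and 1 pixel size
shown = correlations_done(1_000, num_defocus=13, num_pixel_sizes=1)
```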

Source code in src/leopard_em/backend/core_match_template.py
def setup_progress_tracking(
    index_queue: "MultiprocessWorkIndexQueue",
    unit_scale: Union[float, int],
    devices: list[torch.device],
) -> tuple[tqdm.tqdm, dict[int, tqdm.tqdm]]:
    """Setup global and per-device tqdm progress bars for template matching.

    Parameters
    ----------
    index_queue : MultiprocessWorkIndexQueue
        The shared work queue tracking global indices.
    unit_scale : Union[float, int]
        Scaling factor applied to the progress bar counts (passed to tqdm as
        ``unit_scale``).
    devices : list[torch.device]
        List of devices to create per-device progress bars for.

    Returns
    -------
    tuple[tqdm.tqdm, dict[int, tqdm.tqdm]]
        Global progress bar and dictionary of per-device progress bars.
    """
    # Global progress bar
    global_pbar = tqdm.tqdm(
        total=index_queue.total_indices,
        desc="2DTM progress",
        dynamic_ncols=True,
        smoothing=0.02,
        unit="corr",
        unit_scale=unit_scale,
    )

    # Per-device progress bars
    device_pbars = {
        i: tqdm.tqdm(
            desc=f"device - {d.type} {d.index}",
            dynamic_ncols=True,
            smoothing=0.02,
            unit="corr",
            unit_scale=unit_scale,
            position=i + 1,  # place below the global bar
            leave=True,
        )
        for i, d in enumerate(devices)
    }

    return global_pbar, device_pbars