core_match_template

Pure PyTorch implementation of the whole-orientation search backend.

construct_multi_gpu_match_template_kwargs(image_dft, template_dft, euler_angles, projective_filters, defocus_values, pixel_values, orientation_batch_size, devices)

Split orientations across the requested devices.

See the core_match_template function for further descriptions of the input parameters.

Parameters:

image_dft : Tensor
    DFT of the image. (required)
template_dft : Tensor
    DFT of the template. (required)
euler_angles : Tensor
    Euler angles to search. (required)
projective_filters : Tensor
    Filters to apply to each projection. (required)
defocus_values : Tensor
    Corresponding defocus values for each filter. (required)
pixel_values : Tensor
    Corresponding pixel size values for each filter. (required)
orientation_batch_size : int
    Number of projections to calculate at once. (required)
devices : list[device]
    List of devices to split the orientations across. (required)

Returns:

list[dict[str, Tensor | int]]
    List of dictionaries containing the kwargs to call the single-GPU
    function. Each index in the list corresponds to a different device, and
    all tensors in the dictionary have been allocated to that device.
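
As an illustration, here is a minimal usage sketch, assuming two CUDA devices are available; all tensor shapes and values below are placeholders rather than outputs of a real preprocessing pipeline:

import torch

from leopard_em.backend.core_match_template import (
    construct_multi_gpu_match_template_kwargs,
)

# Placeholder inputs with illustrative shapes: (H, W // 2 + 1) image RFFT,
# (l, h, w // 2 + 1) template RFFT, and a small stack of projective filters.
image_dft = torch.zeros(512, 257, dtype=torch.complex64)
template_dft = torch.zeros(64, 64, 33, dtype=torch.complex64)
euler_angles = torch.zeros(10_000, 3)
projective_filters = torch.ones(5, 64, 33)
defocus_values = torch.linspace(-1000.0, 1000.0, 5)
pixel_values = torch.tensor([1.06])
devices = [torch.device("cuda:0"), torch.device("cuda:1")]

kwargs_per_device = construct_multi_gpu_match_template_kwargs(
    image_dft=image_dft,
    template_dft=template_dft,
    euler_angles=euler_angles,
    projective_filters=projective_filters,
    defocus_values=defocus_values,
    pixel_values=pixel_values,
    orientation_batch_size=32,
    devices=devices,
)

# One kwargs dict per device, each holding half of the 10,000 orientations.
print(len(kwargs_per_device))                      # 2
print(kwargs_per_device[0]["euler_angles"].shape)  # torch.Size([5000, 3])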

Source code in src/leopard_em/backend/core_match_template.py
def construct_multi_gpu_match_template_kwargs(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    euler_angles: torch.Tensor,
    projective_filters: torch.Tensor,
    defocus_values: torch.Tensor,
    pixel_values: torch.Tensor,
    orientation_batch_size: int,
    devices: list[torch.device],
) -> list[dict[str, torch.Tensor | int]]:
    """Split orientations between requested devices.

    See the `core_match_template` function for further descriptions of the
    input parameters.

    Parameters
    ----------
    image_dft : torch.Tensor
        DFT of the image.
    template_dft : torch.Tensor
        DFT of the template.
    euler_angles : torch.Tensor
        Euler angles to search.
    projective_filters : torch.Tensor
        Filters to apply to each projection.
    defocus_values : torch.Tensor
        Corresponding defocus values for each filter.
    pixel_values : torch.Tensor
        Corresponding pixel size values for each filter.
    orientation_batch_size : int
        Number of projections to calculate at once.
    devices : list[torch.device]
        List of devices to split the orientations across.

    Returns
    -------
    list[dict[str, torch.Tensor | int]]
        List of dictionaries containing the kwargs to call the single-GPU
        function. Each index in the list corresponds to a different device,
        and all tensors in the dictionary have been allocated to that device.
    """
    kwargs_per_device = []

    # Split the euler angles across devices
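    # (torch.chunk may return fewer chunks than requested when there are fewer
    # orientations than devices; zip below then leaves the extra devices idle.)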
    euler_angles_split = euler_angles.chunk(len(devices))

    for device, euler_angles_device in zip(devices, euler_angles_split):
        # Allocate and construct the kwargs for this device
        kwargs = {
            "image_dft": image_dft.to(device),
            "template_dft": template_dft.to(device),
            "euler_angles": euler_angles_device.to(device),
            "projective_filters": projective_filters.to(device),
            "defocus_values": defocus_values.to(device),
            "pixel_values": pixel_values.to(device),
            "orientation_batch_size": orientation_batch_size,
        }

        kwargs_per_device.append(kwargs)

    return kwargs_per_device

core_match_template(image_dft, template_dft, ctf_filters, whitening_filter_template, defocus_values, pixel_values, euler_angles, device, orientation_batch_size=1)

Core function for performing the whole-orientation search.

With the RFFT, the last dimension (fastest dimension) is half the width of the input, hence the shape of W // 2 + 1 instead of W for some of the input parameters.
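This halving is easy to verify directly with torch.fft.rfft2:

import torch

H, W = 512, 512
image = torch.rand(H, W)

# The RFFT stores only the non-redundant half of the spectrum along the
# last dimension, so the output width is W // 2 + 1 rather than W.
image_dft = torch.fft.rfft2(image)
print(image_dft.shape)  # torch.Size([512, 257])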

Parameters:

image_dft : Tensor
    Real Fourier transform (RFFT) of the image with large image filters
    already applied. Has shape (H, W // 2 + 1). (required)
template_dft : Tensor
    Real Fourier transform (RFFT) of the template volume to take Fourier
    slices from. Has shape (l, h, w // 2 + 1), where l is the number of
    slices. (required)
ctf_filters : Tensor
    Stack of CTF filters at different defocus values to use in the search.
    Has shape (defocus_batch, h, w // 2 + 1). (required)
whitening_filter_template : Tensor
    Whitening filter for the template volume. Has shape (h, w // 2 + 1).
    Gets multiplied with the CTF filters to create a filter stack. (required)
defocus_values : Tensor
    Defocus values corresponding to each CTF filter. Has shape
    (defocus_batch,). (required)
pixel_values : Tensor
    Pixel size values corresponding to each CTF filter. Has shape
    (pixel_size_batch,). (required)
euler_angles : Tensor
    Euler angles (in 'ZYZ' convention) to search over. Has shape
    (orientations, 3). (required)
device : device | list[device]
    Device or devices to split computation across. (required)
orientation_batch_size : int
    Number of projections to calculate at once on each device. (default: 1)

Returns:

dict[str, Tensor]
    Dictionary containing the following key-value pairs:

    - "mip": Maximum intensity projection of the cross-correlation values
      across the orientation and defocus search space.
    - "scaled_mip": Z-score scaled MIP of the cross-correlation values.
    - "best_phi": Best phi angle for each pixel.
    - "best_theta": Best theta angle for each pixel.
    - "best_psi": Best psi angle for each pixel.
    - "best_defocus": Best defocus value for each pixel.
    - "correlation_mean": Mean of the cross-correlation values for each pixel.
    - "correlation_variance": Variance of the cross-correlation values for
      each pixel.
    - "total_projections": Total number of projections calculated.
    - "total_orientations": Total number of orientations searched.
    - "total_defocus": Total number of defocus values searched.
Source code in src/leopard_em/backend/core_match_template.py
def core_match_template(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,  # already fftshifted
    ctf_filters: torch.Tensor,
    whitening_filter_template: torch.Tensor,
    defocus_values: torch.Tensor,
    pixel_values: torch.Tensor,
    euler_angles: torch.Tensor,
    device: torch.device | list[torch.device],
    orientation_batch_size: int = 1,
) -> dict[str, torch.Tensor]:
    """Core function for performing the whole-orientation search.

    With the RFFT, the last dimension (fastest dimension) is half the width
    of the input, hence the shape of W // 2 + 1 instead of W for some of the
    input parameters.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real Fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real Fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1), where l is the number of
        slices.
    ctf_filters : torch.Tensor
        Stack of CTF filters at different defocus values to use in the search.
        Has shape (defocus_batch, h, w // 2 + 1).
    whitening_filter_template : torch.Tensor
        Whitening filter for the template volume. Has shape (h, w // 2 + 1).
        Gets multiplied with the CTF filters to create a filter stack.
    defocus_values : torch.Tensor
        Defocus values corresponding to each CTF filter. Has shape
        (defocus_batch,).
    pixel_values : torch.Tensor
        Pixel size values corresponding to each CTF filter. Has shape
        (pixel_size_batch,).
    euler_angles : torch.Tensor
        Euler angles (in 'ZYZ' convention) to search over. Has shape
        (orientations, 3).
    device : torch.device | list[torch.device]
        Device or devices to split computation across.
    orientation_batch_size : int, optional
        Number of projections to calculate at once on each device. Default
        is 1.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing the following key-value pairs:

            - "mip": Maximum intensity projection of the cross-correlation values
              across the orientation and defocus search space.
            - "scaled_mip": Z-score scaled MIP of the cross-correlation values.
            - "best_phi": Best phi angle for each pixel.
            - "best_theta": Best theta angle for each pixel.
            - "best_psi": Best psi angle for each pixel.
            - "best_defocus": Best defocus value for each pixel.
            - "correlation_mean": Mean of the cross-correlation values for each
              pixel.
            - "correlation_variance": Variance of the cross-correlation values
              for each pixel.
            - "total_projections": Total number of projections calculated.
            - "total_orientations": Total number of orientations searched.
            - "total_defocus": Total number of defocus values searched.
    """
    ##############################################################
    ### Pre-multiply the whitening filter with the CTF filters ###
    ##############################################################

    projective_filters = ctf_filters * whitening_filter_template[None, None, ...]
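    # Indexing with [None, None, ...] above broadcasts the 2D whitening filter
    # across the leading batch dimension(s) of the CTF filter stack, so every
    # CTF filter is whitened by the same template whitening filter.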

    #########################################
    ### Split orientations across devices ###
    #########################################

    if isinstance(device, torch.device):
        device = [device]

    kwargs_per_device = construct_multi_gpu_match_template_kwargs(
        image_dft=image_dft,
        template_dft=template_dft,
        euler_angles=euler_angles,
        projective_filters=projective_filters,
        defocus_values=defocus_values,
        pixel_values=pixel_values,
        orientation_batch_size=orientation_batch_size,
        devices=device,
    )

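    # Launch one worker per device; run_multiprocess_jobs collects each
    # worker's partial results into a dict keyed by job index.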
    result_dict = run_multiprocess_jobs(
        target=_core_match_template_single_gpu,
        kwargs_list=kwargs_per_device,
    )

    # Get the aggregated results
    partial_results = [result_dict[i] for i in range(len(kwargs_per_device))]
    aggregated_results = aggregate_distributed_results(partial_results)
    mip = aggregated_results["mip"]
    best_phi = aggregated_results["best_phi"]
    best_theta = aggregated_results["best_theta"]
    best_psi = aggregated_results["best_psi"]
    best_defocus = aggregated_results["best_defocus"]
    correlation_sum = aggregated_results["correlation_sum"]
    correlation_squared_sum = aggregated_results["correlation_squared_sum"]
    total_projections = aggregated_results["total_projections"]

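    # Rescale the raw MIP into a z-score map using the per-pixel correlation
    # mean and variance accumulated over the entire search.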
    mip_scaled = torch.empty_like(mip)
    mip, mip_scaled, correlation_mean, correlation_variance = scale_mip(
        mip=mip,
        mip_scaled=mip_scaled,
        correlation_sum=correlation_sum,
        correlation_squared_sum=correlation_squared_sum,
        total_correlation_positions=total_projections,
    )

    return {
        "mip": mip,
        "scaled_mip": mip_scaled,
        "best_phi": best_phi,
        "best_theta": best_theta,
        "best_psi": best_psi,
        "best_defocus": best_defocus,
        "correlation_mean": correlation_mean,
        "correlation_variance": correlation_variance,
        "total_projections": total_projections,
        "total_orientations": euler_angles.shape[0],
        "total_defocus": defocus_values.shape[0],
    }