core_refine_template

Backend functions related to correlating and refining particle stacks.

combine_euler_angles(angle_a, angle_b)

Helper function for composing rotations defined by two sets of Euler angles.

Source code in src/leopard_em/backend/core_refine_template.py
def combine_euler_angles(angle_a: torch.Tensor, angle_b: torch.Tensor) -> torch.Tensor:
    """Helper function for composing rotations defined by two sets of Euler angles."""
    rotmat_a = roma.euler_to_rotmat(
        EULER_ANGLE_FMT, angle_a, degrees=True, device=angle_a.device
    )
    rotmat_b = roma.euler_to_rotmat(
        EULER_ANGLE_FMT, angle_b, degrees=True, device=angle_b.device
    )
    rotmat_c = roma.rotmat_composition((rotmat_a, rotmat_b))
    euler_angles_c = roma.rotmat_to_euler(EULER_ANGLE_FMT, rotmat_c, degrees=True)

    return euler_angles_c
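
A minimal sketch of calling this helper directly (the specific angle values are hypothetical; the Euler-angle convention is whatever EULER_ANGLE_FMT is set to in this module):

import torch
from leopard_em.backend.core_refine_template import combine_euler_angles

# Composing an orientation with a zero rotation should give back the original
# angles (up to the usual Euler-angle representation ambiguities).
angle_a = torch.tensor([[30.0, 45.0, 60.0]])  # (1, 3), degrees
identity = torch.zeros_like(angle_a)

composed = combine_euler_angles(angle_a, identity)
print(composed)  # expected to be close to [[30.0, 45.0, 60.0]]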

construct_multi_gpu_refine_template_kwargs(particle_stack_dft, template_dft, euler_angles, euler_angle_offsets, defocus_u, defocus_v, defocus_angle, defocus_offsets, pixel_size_offsets, corr_mean, corr_std, ctf_kwargs, projective_filters, batch_size, devices, num_cuda_streams)

Split particle stack between requested devices.

Parameters:

particle_stack_dft (Tensor): Particle stack to split. Required.
template_dft (Tensor): Template volume. Required.
euler_angles (Tensor): Euler angles for each particle. Required.
euler_angle_offsets (Tensor): Euler angle offsets to search over. Required.
defocus_u (Tensor): Defocus U values for each particle. Required.
defocus_v (Tensor): Defocus V values for each particle. Required.
defocus_angle (Tensor): Defocus angle values for each particle. Required.
defocus_offsets (Tensor): Defocus offsets to search over. Required.
pixel_size_offsets (Tensor): Pixel size offsets to search over. Required.
corr_mean (Tensor): Mean of the cross-correlation. Required.
corr_std (Tensor): Standard deviation of the cross-correlation. Required.
ctf_kwargs (dict): CTF calculation parameters. Required.
projective_filters (Tensor): Projective filters for each particle. Required.
batch_size (int): Batch size for orientation processing. Required.
devices (list[device]): List of devices to split across. Required.
num_cuda_streams (int): Number of CUDA streams to use per device. Required.

Returns:

list[dict]: List of dictionaries containing the kwargs to call the single-GPU function.

Source code in src/leopard_em/backend/core_refine_template.py
def construct_multi_gpu_refine_template_kwargs(
    particle_stack_dft: torch.Tensor,
    template_dft: torch.Tensor,
    euler_angles: torch.Tensor,
    euler_angle_offsets: torch.Tensor,
    defocus_u: torch.Tensor,
    defocus_v: torch.Tensor,
    defocus_angle: torch.Tensor,
    defocus_offsets: torch.Tensor,
    pixel_size_offsets: torch.Tensor,
    corr_mean: torch.Tensor,
    corr_std: torch.Tensor,
    ctf_kwargs: dict,
    projective_filters: torch.Tensor,
    batch_size: int,
    devices: list[torch.device],
    num_cuda_streams: int,
) -> list[dict]:
    """Split particle stack between requested devices.

    Parameters
    ----------
    particle_stack_dft : torch.Tensor
        Particle stack to split.
    template_dft : torch.Tensor
        Template volume.
    euler_angles : torch.Tensor
        Euler angles for each particle.
    euler_angle_offsets : torch.Tensor
        Euler angle offsets to search over.
    defocus_u : torch.Tensor
        Defocus U values for each particle.
    defocus_v : torch.Tensor
        Defocus V values for each particle.
    defocus_angle : torch.Tensor
        Defocus angle values for each particle.
    defocus_offsets : torch.Tensor
        Defocus offsets to search over.
    pixel_size_offsets : torch.Tensor
        Pixel size offsets to search over.
    corr_mean : torch.Tensor
        Mean of the cross-correlation.
    corr_std : torch.Tensor
        Standard deviation of the cross-correlation.
    ctf_kwargs : dict
        CTF calculation parameters.
    projective_filters : torch.Tensor
        Projective filters for each particle.
    batch_size : int
        Batch size for orientation processing.
    devices : list[torch.device]
        List of devices to split across.
    num_cuda_streams : int
        Number of CUDA streams to use per device.

    Returns
    -------
    list[dict]
        List of dictionaries containing the kwargs to call the single-GPU function.
    """
    num_devices = len(devices)
    kwargs_per_device = []
    num_particles = particle_stack_dft.shape[0]

    # Calculate how many particles to assign to each device
    particles_per_device = [num_particles // num_devices] * num_devices
    # Distribute remaining particles
    for i in range(num_particles % num_devices):
        particles_per_device[i] += 1

    # Split the particle stack across devices
    start_idx = 0
    for device_idx, num_device_particles in enumerate(particles_per_device):
        if num_device_particles == 0:
            continue

        end_idx = start_idx + num_device_particles
        device = devices[device_idx]

        # Get particle indices for this device
        particle_indices = torch.arange(start_idx, end_idx)

        # Split tensors for this device. All these tensors are per-particle, that is
        # the i-th element in each tensor corresponds to the i-th particle in the stack.
        device_particle_stack_dft = particle_stack_dft[start_idx:end_idx]
        device_euler_angles = euler_angles[start_idx:end_idx]
        device_defocus_u = defocus_u[start_idx:end_idx]
        device_defocus_v = defocus_v[start_idx:end_idx]
        device_defocus_angle = defocus_angle[start_idx:end_idx]
        device_projective_filters = projective_filters[start_idx:end_idx]

        kwargs = {
            "particle_stack_dft": device_particle_stack_dft,
            "particle_indices": particle_indices,
            "template_dft": template_dft,
            "euler_angles": device_euler_angles,
            "euler_angle_offsets": euler_angle_offsets,
            "defocus_u": device_defocus_u,
            "defocus_v": device_defocus_v,
            "defocus_angle": device_defocus_angle,
            "defocus_offsets": defocus_offsets,
            "pixel_size_offsets": pixel_size_offsets,
            "corr_mean": corr_mean,
            "corr_std": corr_std,
            "projective_filters": device_projective_filters,
            "ctf_kwargs": ctf_kwargs,
            "batch_size": batch_size,
            "num_cuda_streams": num_cuda_streams,
            "device": device,
        }

        kwargs_per_device.append(kwargs)
        start_idx = end_idx

    return kwargs_per_device
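
A minimal sketch of the splitting behavior, using small dummy tensors with hypothetical shapes (the function only slices the per-particle tensors and packages them with the shared search tensors, so nothing is actually moved onto the devices here):

import torch
from leopard_em.backend.core_refine_template import (
    construct_multi_gpu_refine_template_kwargs,
)

n_particles = 5  # dummy stack of 5 particles, 16x16 images stored as RFFTs
kwargs_list = construct_multi_gpu_refine_template_kwargs(
    particle_stack_dft=torch.randn(n_particles, 16, 9, dtype=torch.complex64),
    template_dft=torch.randn(8, 8, 5, dtype=torch.complex64),
    euler_angles=torch.zeros(n_particles, 3),
    euler_angle_offsets=torch.zeros(4, 3),
    defocus_u=torch.full((n_particles,), 1.2e4),
    defocus_v=torch.full((n_particles,), 1.1e4),
    defocus_angle=torch.zeros(n_particles),
    defocus_offsets=torch.linspace(-200.0, 200.0, 5),
    pixel_size_offsets=torch.zeros(1),
    corr_mean=torch.zeros(9, 9),
    corr_std=torch.ones(9, 9),
    ctf_kwargs={},
    projective_filters=torch.ones(n_particles, 8, 5),
    batch_size=32,
    devices=[torch.device("cuda:0"), torch.device("cuda:1")],
    num_cuda_streams=1,
)

# With 5 particles and 2 devices, the first device receives 3 particles and
# the second receives 2; the shared search tensors appear in every dict.
for kw in kwargs_list:
    print(kw["device"], kw["particle_stack_dft"].shape[0])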

core_refine_template(particle_stack_dft, template_dft, euler_angles, euler_angle_offsets, defocus_offsets, defocus_u, defocus_v, defocus_angle, pixel_size_offsets, corr_mean, corr_std, ctf_kwargs, projective_filters, device, batch_size=32, num_cuda_streams=1)

Core function to refine orientations and defoci of a set of particles.

Parameters:

particle_stack_dft (Tensor): The stack of particle real-Fourier transformed and un-fftshifted images. Shape of (N, H, W). Required.
template_dft (Tensor): The template volume to extract central slices from. Real-Fourier transformed and fftshifted. Required.
euler_angles (Tensor): The Euler angles for each particle in the stack. Shape of (N, 3). Required.
euler_angle_offsets (Tensor): The Euler angle offsets to apply to each particle. Shape of (k, 3). Required.
defocus_u (Tensor): The defocus along the major axis for each particle in the stack. Shape of (N,). Required.
defocus_v (Tensor): The defocus along the minor axis for each particle in the stack. Shape of (N,). Required.
defocus_angle (Tensor): The defocus astigmatism angle for each particle in the stack. Shape of (N,). This is the same as the defocus angle of the micrograph the particle came from. Required.
defocus_offsets (Tensor): The defocus offsets to search over for each particle. Shape of (l,). Required.
pixel_size_offsets (Tensor): The pixel size offsets to search over for each particle. Shape of (m,). Required.
corr_mean (Tensor): The mean of the cross-correlation values from the full orientation search for the pixels around the center of the particle. Shape of (H - h + 1, W - w + 1). Required.
corr_std (Tensor): The standard deviation of the cross-correlation values from the full orientation search for the pixels around the center of the particle. Shape of (H - h + 1, W - w + 1). Required.
ctf_kwargs (dict): Keyword arguments to pass to the CTF calculation function. Required.
projective_filters (Tensor): Projective filters to apply to each Fourier slice particle. Shape of (N, h, w). Required.
device (device | list[device]): Device or list of devices to use for processing. Required.
batch_size (int): The number of cross-correlations to process in one batch. Default: 32.
num_cuda_streams (int): Number of CUDA streams to use for parallel processing. Default: 1.

Returns:

dict[str, Tensor]: Dictionary containing the refined parameters for all particles.

Source code in src/leopard_em/backend/core_refine_template.py
def core_refine_template(
    particle_stack_dft: torch.Tensor,  # (N, H, W)
    template_dft: torch.Tensor,  # (d, h, w)
    euler_angles: torch.Tensor,  # (N, 3)
    euler_angle_offsets: torch.Tensor,  # (k, 3)
    defocus_offsets: torch.Tensor,  # (l,)
    defocus_u: torch.Tensor,  # (N,)
    defocus_v: torch.Tensor,  # (N,)
    defocus_angle: torch.Tensor,  # (N,)
    pixel_size_offsets: torch.Tensor,  # (m,)
    corr_mean: torch.Tensor,  # (N, H - h + 1, W - w + 1)
    corr_std: torch.Tensor,  # (N, H - h + 1, W - w + 1)
    ctf_kwargs: dict,
    projective_filters: torch.Tensor,  # (N, h, w)
    device: torch.device | list[torch.device],
    batch_size: int = 32,
    num_cuda_streams: int = 1,
) -> dict[str, torch.Tensor]:
    """Core function to refine orientations and defoci of a set of particles.

    Parameters
    ----------
    particle_stack_dft : torch.Tensor
        The stack of particle real-Fourier transformed and un-fftshifted images.
        Shape of (N, H, W).
    template_dft : torch.Tensor
        The template volume to extract central slices from. Real-Fourier transformed
        and fftshifted.
    euler_angles : torch.Tensor
        The Euler angles for each particle in the stack. Shape of (N, 3).
    euler_angle_offsets : torch.Tensor
        The Euler angle offsets to apply to each particle. Shape of (k, 3).
    defocus_u : torch.Tensor
        The defocus along the major axis for each particle in the stack. Shape of (N,).
    defocus_v : torch.Tensor
        The defocus along the minor axis for each particle in the stack. Shape of (N,).
    defocus_angle : torch.Tensor
        The defocus astigmatism angle for each particle in the stack. Shape of (N,).
        This is the same as the defocus angle of the micrograph the particle
        came from.
    defocus_offsets : torch.Tensor
        The defocus offsets to search over for each particle. Shape of (l,).
    pixel_size_offsets : torch.Tensor
        The pixel size offsets to search over for each particle. Shape of (m,).
    corr_mean : torch.Tensor
        The mean of the cross-correlation values from the full orientation search
        for the pixels around the center of the particle.
        Shape of (H - h + 1, W - w + 1).
    corr_std : torch.Tensor
        The standard deviation of the cross-correlation values from the full
        orientation search for the pixels around the center of the particle.
        Shape of (H - h + 1, W - w + 1).
    ctf_kwargs : dict
        Keyword arguments to pass to the CTF calculation function.
    projective_filters : torch.Tensor
        Projective filters to apply to each Fourier slice particle. Shape of (N, h, w).
    device : torch.device | list[torch.device]
        Device or list of devices to use for processing.
    batch_size : int, optional
        The number of cross-correlations to process in one batch, defaults to 32.
    num_cuda_streams : int, optional
        Number of CUDA streams to use for parallel processing. Defaults to 1.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing the refined parameters for all particles.
    """
    # Convert single device to list for consistent handling
    if isinstance(device, torch.device):
        device = [device]

    ###########################################
    ### Split particle stack across devices ###
    ###########################################

    kwargs_per_device = construct_multi_gpu_refine_template_kwargs(
        particle_stack_dft=particle_stack_dft,
        template_dft=template_dft,
        euler_angles=euler_angles,
        euler_angle_offsets=euler_angle_offsets,
        defocus_u=defocus_u,
        defocus_v=defocus_v,
        defocus_angle=defocus_angle,
        defocus_offsets=defocus_offsets,
        pixel_size_offsets=pixel_size_offsets,
        corr_mean=corr_mean,
        corr_std=corr_std,
        ctf_kwargs=ctf_kwargs,
        projective_filters=projective_filters,
        batch_size=batch_size,
        devices=device,
        num_cuda_streams=num_cuda_streams,
    )

    results = run_multiprocess_jobs(
        target=_core_refine_template_single_gpu,
        kwargs_list=kwargs_per_device,
    )

    # Synchronize all devices to ensure all computations are complete
    for dev in device:
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)

    # Shape information for offset calculations
    _, img_h, img_w = particle_stack_dft.shape
    _, template_h, template_w = template_dft.shape
    # account for RFFT
    img_w = 2 * (img_w - 1)
    template_w = 2 * (template_w - 1)

    # Concatenate results from all devices
    refined_cross_correlation = torch.cat(
        [torch.from_numpy(r["refined_cross_correlation"]) for r in results.values()]
    )
    refined_z_score = torch.cat(
        [torch.from_numpy(r["refined_z_score"]) for r in results.values()]
    )
    refined_euler_angles = torch.cat(
        [torch.from_numpy(r["refined_euler_angles"]) for r in results.values()]
    )
    refined_defocus_offset = torch.cat(
        [torch.from_numpy(r["refined_defocus_offset"]) for r in results.values()]
    )
    refined_pixel_size_offset = torch.cat(
        [torch.from_numpy(r["refined_pixel_size_offset"]) for r in results.values()]
    )
    refined_pos_y = torch.cat(
        [torch.from_numpy(r["refined_pos_y"]) for r in results.values()]
    )
    refined_pos_x = torch.cat(
        [torch.from_numpy(r["refined_pos_x"]) for r in results.values()]
    )

    # Ensure the results are sorted back to the original particle order
    # (If particles were split across devices, we need to reorder the results)
    particle_indices = torch.cat(
        [torch.from_numpy(r["particle_indices"]) for r in results.values()]
    )
    angle_idx = torch.cat([torch.from_numpy(r["angle_idx"]) for r in results.values()])
    sort_indices = torch.argsort(particle_indices)

    refined_cross_correlation = refined_cross_correlation[sort_indices]
    refined_z_score = refined_z_score[sort_indices]
    refined_euler_angles = refined_euler_angles[sort_indices]
    refined_defocus_offset = refined_defocus_offset[sort_indices]
    refined_pixel_size_offset = refined_pixel_size_offset[sort_indices]
    refined_pos_y = refined_pos_y[sort_indices]
    refined_pos_x = refined_pos_x[sort_indices]
    angle_idx = angle_idx[sort_indices]

    # Offset refined_pos_{x,y} by the extracted box size (same as original)
    refined_pos_y -= (img_h - template_h + 1) // 2
    refined_pos_x -= (img_w - template_w + 1) // 2

    return {
        "refined_cross_correlation": refined_cross_correlation,
        "refined_z_score": refined_z_score,
        "refined_euler_angles": refined_euler_angles,
        "refined_defocus_offset": refined_defocus_offset,
        "refined_pixel_size_offset": refined_pixel_size_offset,
        "refined_pos_y": refined_pos_y,
        "refined_pos_x": refined_pos_x,
        "angle_idx": angle_idx,
    }
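
The keys of the returned dictionary mirror the concatenated tensors above, each ordered to match the input particle stack. A hedged sketch of consuming that dictionary follows; the placeholder values are dummies, and folding the defocus offset back into the per-particle defocus is an assumption about downstream use rather than code from this module:

import torch

n = 3  # placeholder result dictionary for three particles (dummy values)
results = {
    "refined_cross_correlation": torch.zeros(n),
    "refined_z_score": torch.zeros(n),
    "refined_euler_angles": torch.zeros(n, 3),
    "refined_defocus_offset": torch.zeros(n),
    "refined_pixel_size_offset": torch.zeros(n),
    "refined_pos_y": torch.zeros(n, dtype=torch.long),
    "refined_pos_x": torch.zeros(n, dtype=torch.long),
    "angle_idx": torch.zeros(n, dtype=torch.long),
}

# Assumed downstream step: add the refined defocus offset to the original
# per-particle defocus values.
defocus_u = torch.full((n,), 1.2e4)
defocus_v = torch.full((n,), 1.1e4)
refined_defocus_u = defocus_u + results["refined_defocus_offset"]
refined_defocus_v = defocus_v + results["refined_defocus_offset"]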

cross_correlate_particle_stack(particle_stack_dft, template_dft, rotation_matrices, projective_filters, mode='valid', batch_size=1024)

Cross-correlate a stack of particle images against a template.

Here, the argument 'particle_stack_dft' is a set of RFFT-ed particle images with necessary filtering already applied. The zeroth dimension corresponds to unique particles.
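
For instance, a real-space particle stack could be converted into the expected form with a plain RFFT and no fftshift (any upstream filtering or whitening is omitted from this hypothetical snippet):

import torch

particle_images = torch.randn(8, 256, 256)  # (N, H, W) real-space stack (dummy data)
particle_stack_dft = torch.fft.rfftn(particle_images, dim=(-2, -1))
print(particle_stack_dft.shape)  # (8, 256, 129): last dim becomes W // 2 + 1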

Parameters:

particle_stack_dft (Tensor): The stack of particle real-Fourier transformed and un-fftshifted images. Shape of (N, H, W). Required.
template_dft (Tensor): The template volume to extract central slices from. Real-Fourier transformed and fftshifted. Required.
rotation_matrices (Tensor): The orientations of the particles to take the Fourier slices of, as a long list of rotation matrices. Shape of (N, 3, 3). Required.
projective_filters (Tensor): Projective filters to apply to each Fourier slice particle. Shape of (N, h, w). Required.
mode (Literal['valid', 'same']): Correlation mode to use. If "valid", the output will be the valid cross-correlation of the inputs. If "same", the output will be the same shape as the input particle stack. Default: 'valid'.
batch_size (int): The number of particle images to cross-correlate at once. Larger sizes will consume more memory. If -1, the entire stack will be cross-correlated at once. Default: 1024.

Returns:

Tensor: The cross-correlation of the particle stack with the template. Shape depends on the mode used: (N, H-h+1, W-w+1) for "valid", (N, H, W) for "same".

Raises:

ValueError: If the mode is not "valid" or "same".

Source code in src/leopard_em/backend/core_refine_template.py
def cross_correlate_particle_stack(
    particle_stack_dft: torch.Tensor,  # (N, H, W)
    template_dft: torch.Tensor,  # (d, h, w)
    rotation_matrices: torch.Tensor,  # (N, 3, 3)
    projective_filters: torch.Tensor,  # (N, h, w)
    mode: Literal["valid", "same"] = "valid",
    batch_size: int = 1024,
) -> torch.Tensor:
    """Cross-correlate a stack of particle images against a template.

    Here, the argument 'particle_stack_dft' is a set of RFFT-ed particle images with
    necessary filtering already applied. The zeroth dimension corresponds to unique
    particles.

    Parameters
    ----------
    particle_stack_dft : torch.Tensor
        The stack of particle real-Fourier transformed and un-fftshifted images.
        Shape of (N, H, W).
    template_dft : torch.Tensor
        The template volume to extract central slices from. Real-Fourier transformed
        and fftshifted.
    rotation_matrices : torch.Tensor
        The orientations of the particles to take the Fourier slices of, as a long
        list of rotation matrices. Shape of (N, 3, 3).
    projective_filters : torch.Tensor
        Projective filters to apply to each Fourier slice particle. Shape of (N, h, w).
    mode : Literal["valid", "same"], optional
        Correlation mode to use, by default "valid". If "valid", the output will be
        the valid cross-correlation of the inputs. If "same", the output will be the
        same shape as the input particle stack.
    batch_size : int, optional
        The number of particle images to cross-correlate at once. Default is 1024.
        Larger sizes will consume more memory. If -1, then the entire stack will be
        cross-correlated at once.

    Returns
    -------
    torch.Tensor
        The cross-correlation of the particle stack with the template. Shape will depend
        on the mode used. If "valid", the output will be (N, H-h+1, W-w+1). If "same",
        the output will be (N, H, W).

    Raises
    ------
    ValueError
        If the mode is not "valid" or "same".
    """
    # Helpful constants for later use
    device = particle_stack_dft.device
    num_particles, image_h, image_w = particle_stack_dft.shape
    _, template_h, template_w = template_dft.shape
    # account for RFFT
    image_w = 2 * (image_w - 1)
    template_w = 2 * (template_w - 1)

    if batch_size == -1:
        batch_size = num_particles

    if mode == "valid":
        output_shape = (
            num_particles,
            image_h - template_h + 1,
            image_w - template_w + 1,
        )
    elif mode == "same":
        output_shape = (num_particles, image_h, image_w)
    else:
        raise ValueError(f"Invalid mode: {mode}. Must be 'valid' or 'same'.")

    out_correlation = torch.zeros(output_shape, device=device)

    # Loop over the particle stack in batches
    for i in range(0, num_particles, batch_size):
        batch_particles_dft = particle_stack_dft[i : i + batch_size]
        batch_rotation_matrices = rotation_matrices[i : i + batch_size]
        batch_projective_filters = projective_filters[i : i + batch_size]

        # Extract the Fourier slice and apply the projective filters
        fourier_slice = extract_central_slices_rfft_3d(
            volume_rfft=template_dft,
            image_shape=(template_h,) * 3,
            rotation_matrices=batch_rotation_matrices,
        )
        fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
        fourier_slice[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
        fourier_slice *= -1  # flip contrast
        fourier_slice *= batch_projective_filters

        # Inverse Fourier transform and normalize the projection
        projections = torch.fft.irfftn(fourier_slice, dim=(-2, -1))
        projections = torch.fft.ifftshift(projections, dim=(-2, -1))
        projections = normalize_template_projection(
            projections, (template_h, template_w), (image_h, image_w)
        )

        # Padded forward FFT and cross-correlate
        projections_dft = torch.fft.rfftn(
            projections, dim=(-2, -1), s=(image_h, image_w)
        )
        projections_dft = batch_particles_dft * projections_dft.conj()
        cross_correlation = torch.fft.irfftn(projections_dft, dim=(-2, -1))

        # Handle the output shape
        cross_correlation = handle_correlation_mode(
            cross_correlation, output_shape, mode
        )

        out_correlation[i : i + batch_size] = cross_correlation

    return out_correlation
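
A minimal, hedged sketch of calling the function on dummy data. The fftshift convention for the template RFFT volume (shifted over the two full-length dimensions) is assumed to match the upstream Fourier-slice extraction, and identity rotations plus all-ones filters are used so the call reduces to a plain projection-and-correlate pass:

import torch
from leopard_em.backend.core_refine_template import cross_correlate_particle_stack

box = 32
template = torch.randn(box, box, box)
# Assumed convention: RFFT of the volume, fftshifted over the full-length dims.
template_dft = torch.fft.fftshift(torch.fft.rfftn(template), dim=(-3, -2))

images = torch.randn(4, 128, 128)
particle_stack_dft = torch.fft.rfftn(images, dim=(-2, -1))  # un-fftshifted RFFTs

rotation_matrices = torch.eye(3).expand(4, 3, 3).contiguous()  # identity orientations
projective_filters = torch.ones(4, box, box // 2 + 1)  # no-op filters

cc = cross_correlate_particle_stack(
    particle_stack_dft=particle_stack_dft,
    template_dft=template_dft,
    rotation_matrices=rotation_matrices,
    projective_filters=projective_filters,
    mode="valid",
    batch_size=2,
)
print(cc.shape)  # (4, 128 - 32 + 1, 128 - 32 + 1) == (4, 97, 97)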