core_refine_template

Backend functions related to correlating and refining particle stacks.

combine_euler_angles(angle_a, angle_b)

Helper function for composing rotations defined by two sets of Euler angles.

Source code in src/leopard_em/backend/core_refine_template.py
def combine_euler_angles(angle_a: torch.Tensor, angle_b: torch.Tensor) -> torch.Tensor:
    """Helper function for composing rotations defined by two sets of Euler angles."""
    rotmat_a = roma.euler_to_rotmat(
        EULER_ANGLE_FMT, angle_a, degrees=True, device=angle_a.device
    )
    rotmat_b = roma.euler_to_rotmat(
        EULER_ANGLE_FMT, angle_b, degrees=True, device=angle_b.device
    )
    rotmat_c = roma.rotmat_composition((rotmat_a, rotmat_b))
    euler_angles_c = roma.rotmat_to_euler(EULER_ANGLE_FMT, rotmat_c, degrees=True)

    return euler_angles_c
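
A minimal usage sketch (values are illustrative; it assumes the leopard_em package and its roma dependency are installed):

```python
# Hypothetical usage sketch for combine_euler_angles: compose a particle's
# orientation with a small angular search offset. Values are illustrative only.
import torch
from leopard_em.backend.core_refine_template import combine_euler_angles

base_angles = torch.tensor([[10.0, 20.0, 30.0]])   # (1, 3) Euler angles, degrees
offset_angles = torch.tensor([[0.5, -1.0, 0.0]])   # (1, 3) small offsets, degrees

combined = combine_euler_angles(base_angles, offset_angles)
print(combined.shape)  # torch.Size([1, 3])
```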

construct_multi_gpu_refine_template_kwargs(particle_stack_dft, template_dft, euler_angles, euler_angle_offsets, defocus_u, defocus_v, defocus_angle, defocus_offsets, pixel_size_offsets, corr_mean, corr_std, ctf_kwargs, projective_filters, batch_size, devices)

Split particle stack between requested devices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `particle_stack_dft` | `Tensor` | Particle stack to split. | required |
| `template_dft` | `Tensor` | Template volume. | required |
| `euler_angles` | `Tensor` | Euler angles for each particle. | required |
| `euler_angle_offsets` | `Tensor` | Euler angle offsets to search over. | required |
| `defocus_u` | `Tensor` | Defocus U values for each particle. | required |
| `defocus_v` | `Tensor` | Defocus V values for each particle. | required |
| `defocus_angle` | `Tensor` | Defocus angle values for each particle. | required |
| `defocus_offsets` | `Tensor` | Defocus offsets to search over. | required |
| `pixel_size_offsets` | `Tensor` | Pixel size offsets to search over. | required |
| `corr_mean` | `Tensor` | Mean of the cross-correlation. | required |
| `corr_std` | `Tensor` | Standard deviation of the cross-correlation. | required |
| `ctf_kwargs` | `dict` | CTF calculation parameters. | required |
| `projective_filters` | `Tensor` | Projective filters for each particle. | required |
| `batch_size` | `int` | Batch size for orientation processing. | required |
| `devices` | `list[device]` | List of devices to split across. | required |

Returns:

| Type | Description |
| --- | --- |
| `list[dict]` | List of dictionaries containing the kwargs to call the single-GPU function. |

Source code in src/leopard_em/backend/core_refine_template.py
def construct_multi_gpu_refine_template_kwargs(
    particle_stack_dft: torch.Tensor,
    template_dft: torch.Tensor,
    euler_angles: torch.Tensor,
    euler_angle_offsets: torch.Tensor,
    defocus_u: torch.Tensor,
    defocus_v: torch.Tensor,
    defocus_angle: torch.Tensor,
    defocus_offsets: torch.Tensor,
    pixel_size_offsets: torch.Tensor,
    corr_mean: torch.Tensor,
    corr_std: torch.Tensor,
    ctf_kwargs: dict,
    projective_filters: torch.Tensor,
    batch_size: int,
    devices: list[torch.device],
) -> list[dict]:
    """Split particle stack between requested devices.

    Parameters
    ----------
    particle_stack_dft : torch.Tensor
        Particle stack to split.
    template_dft : torch.Tensor
        Template volume.
    euler_angles : torch.Tensor
        Euler angles for each particle.
    euler_angle_offsets : torch.Tensor
        Euler angle offsets to search over.
    defocus_u : torch.Tensor
        Defocus U values for each particle.
    defocus_v : torch.Tensor
        Defocus V values for each particle.
    defocus_angle : torch.Tensor
        Defocus angle values for each particle.
    defocus_offsets : torch.Tensor
        Defocus offsets to search over.
    pixel_size_offsets : torch.Tensor
        Pixel size offsets to search over.
    corr_mean : torch.Tensor
        Mean of the cross-correlation.
    corr_std : torch.Tensor
        Standard deviation of the cross-correlation.
    ctf_kwargs : dict
        CTF calculation parameters.
    projective_filters : torch.Tensor
        Projective filters for each particle.
    batch_size : int
        Batch size for orientation processing.
    devices : list[torch.device]
        List of devices to split across.

    Returns
    -------
    list[dict]
        List of dictionaries containing the kwargs to call the single-GPU function.
    """
    num_devices = len(devices)
    kwargs_per_device = []
    num_particles = particle_stack_dft.shape[0]

    # Calculate how many particles to assign to each device
    particles_per_device = [num_particles // num_devices] * num_devices
    # Distribute remaining particles
    for i in range(num_particles % num_devices):
        particles_per_device[i] += 1

    # Split the particle stack across devices
    start_idx = 0
    for device_idx, num_device_particles in enumerate(particles_per_device):
        if num_device_particles == 0:
            continue

        end_idx = start_idx + num_device_particles
        device = devices[device_idx]

        # Get particle indices for this device
        particle_indices = torch.arange(start_idx, end_idx)

        # Split tensors for this device
        device_particle_stack_dft = particle_stack_dft[start_idx:end_idx].to(device)
        device_euler_angles = euler_angles[start_idx:end_idx].to(device)
        device_defocus_u = defocus_u[start_idx:end_idx].to(device)
        device_defocus_v = defocus_v[start_idx:end_idx].to(device)
        device_defocus_angle = defocus_angle[start_idx:end_idx].to(device)
        device_projective_filters = projective_filters[start_idx:end_idx].to(device)

        # These are shared across all particles
        device_template_dft = template_dft.to(device)
        device_euler_angle_offsets = euler_angle_offsets.to(device)
        device_defocus_offsets = defocus_offsets.to(device)
        device_pixel_size_offsets = pixel_size_offsets.to(device)
        device_corr_mean = corr_mean.to(device)
        device_corr_std = corr_std.to(device)

        kwargs = {
            "particle_stack_dft": device_particle_stack_dft,
            "particle_indices": particle_indices.cpu().numpy(),
            "template_dft": device_template_dft,
            "euler_angles": device_euler_angles,
            "euler_angle_offsets": device_euler_angle_offsets,
            "defocus_u": device_defocus_u,
            "defocus_v": device_defocus_v,
            "defocus_angle": device_defocus_angle,
            "defocus_offsets": device_defocus_offsets,
            "pixel_size_offsets": device_pixel_size_offsets,
            "corr_mean": device_corr_mean,
            "corr_std": device_corr_std,
            "ctf_kwargs": ctf_kwargs,
            "projective_filters": device_projective_filters,
            "batch_size": batch_size,
        }

        kwargs_per_device.append(kwargs)
        start_idx = end_idx

    return kwargs_per_device
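
For illustration, the splitting above distributes any remainder particles to the earlier devices; a minimal standalone sketch of that arithmetic (not the library function itself):

```python
# Standalone sketch of the particle-splitting arithmetic used above:
# divide num_particles across num_devices, giving the remainder to the first devices.
num_particles, num_devices = 10, 3
particles_per_device = [num_particles // num_devices] * num_devices
for i in range(num_particles % num_devices):
    particles_per_device[i] += 1
print(particles_per_device)  # [4, 3, 3]
```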

core_refine_template(particle_stack_dft, template_dft, euler_angles, euler_angle_offsets, defocus_offsets, defocus_u, defocus_v, defocus_angle, pixel_size_offsets, corr_mean, corr_std, ctf_kwargs, projective_filters, device=None, batch_size=64)

Core function to refine orientations and defoci of a set of particles.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `particle_stack_dft` | `Tensor` | The stack of particle real-Fourier transformed and un-fftshifted images. Shape of (N, H, W). | required |
| `template_dft` | `Tensor` | The template volume to extract central slices from. Real-Fourier transformed and fftshifted. | required |
| `euler_angles` | `Tensor` | The Euler angles for each particle in the stack. Shape of (N, 3). | required |
| `euler_angle_offsets` | `Tensor` | The Euler angle offsets to apply to each particle. Shape of (k, 3). | required |
| `defocus_u` | `Tensor` | The defocus along the major axis for each particle in the stack. Shape of (N,). | required |
| `defocus_v` | `Tensor` | The defocus along the minor axis for each particle in the stack. Shape of (N,). | required |
| `defocus_angle` | `Tensor` | The defocus astigmatism angle for each particle in the stack. Shape of (N,). It is the same as the astigmatism angle of the micrograph the particle came from. | required |
| `defocus_offsets` | `Tensor` | The defocus offsets to search over for each particle. Shape of (l,). | required |
| `pixel_size_offsets` | `Tensor` | The pixel size offsets to search over for each particle. Shape of (m,). | required |
| `corr_mean` | `Tensor` | The mean of the cross-correlation values from the full orientation search for the pixels around the center of the particle. Shape of (H - h + 1, W - w + 1). | required |
| `corr_std` | `Tensor` | The standard deviation of the cross-correlation values from the full orientation search for the pixels around the center of the particle. Shape of (H - h + 1, W - w + 1). | required |
| `ctf_kwargs` | `dict` | Keyword arguments to pass to the CTF calculation function. | required |
| `projective_filters` | `Tensor` | Projective filters to apply to each Fourier slice particle. Shape of (N, h, w). | required |
| `device` | `device \| list[device]` | Device or list of devices to use for processing. | `None` |
| `batch_size` | `int` | The number of orientations to process at once. Default is 64. | `64` |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Tensor]` | Dictionary containing the refined parameters for all particles. |

Source code in src/leopard_em/backend/core_refine_template.py
def core_refine_template(
    particle_stack_dft: torch.Tensor,  # (N, H, W)
    template_dft: torch.Tensor,  # (d, h, w)
    euler_angles: torch.Tensor,  # (N, 3)
    euler_angle_offsets: torch.Tensor,  # (k, 3)
    defocus_offsets: torch.Tensor,  # (l,)
    defocus_u: torch.Tensor,  # (N,)
    defocus_v: torch.Tensor,  # (N,)
    defocus_angle: torch.Tensor,  # (N,)
    pixel_size_offsets: torch.Tensor,  # (m,)
    corr_mean: torch.Tensor,  # (N, H - h + 1, W - w + 1)
    corr_std: torch.Tensor,  # (N, H - h + 1, W - w + 1)
    ctf_kwargs: dict,
    projective_filters: torch.Tensor,  # (N, h, w)
    device: torch.device | list[torch.device] = None,
    batch_size: int = 64,
) -> dict[str, torch.Tensor]:
    """Core function to refine orientations and defoci of a set of particles.

    Parameters
    ----------
    particle_stack_dft : torch.Tensor
        The stack of particle real-Fourier transformed and un-fftshifted images.
        Shape of (N, H, W).
    template_dft : torch.Tensor
        The template volume to extract central slices from. Real-Fourier transformed
        and fftshifted.
    euler_angles : torch.Tensor
        The Euler angles for each particle in the stack. Shape of (N, 3).
    euler_angle_offsets : torch.Tensor
        The Euler angle offsets to apply to each particle. Shape of (k, 3).
    defocus_u : torch.Tensor
        The defocus along the major axis for each particle in the stack. Shape of (N,).
    defocus_v : torch.Tensor
        The defocus along the minor axis for each particle in the stack. Shape of (N,).
    defocus_angle : torch.Tensor
        The defocus astigmatism angle for each particle in the stack. Shape of (N,).
        It is the same as the astigmatism angle of the micrograph the particle came from.
    defocus_offsets : torch.Tensor
        The defocus offsets to search over for each particle. Shape of (l,).
    pixel_size_offsets : torch.Tensor
        The pixel size offsets to search over for each particle. Shape of (m,).
    corr_mean : torch.Tensor
        The mean of the cross-correlation values from the full orientation search
        for the pixels around the center of the particle.
        Shape of (H - h + 1, W - w + 1).
    corr_std : torch.Tensor
        The standard deviation of the cross-correlation values from the full
        orientation search for the pixels around the center of the particle.
        Shape of (H - h + 1, W - w + 1).
    ctf_kwargs : dict
        Keyword arguments to pass to the CTF calculation function.
    projective_filters : torch.Tensor
        Projective filters to apply to each Fourier slice particle. Shape of (N, h, w).
    device : torch.device | list[torch.device], optional
        Device or list of devices to use for processing.
    batch_size : int, optional
        The number of orientations to process at once. Default is 64.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing the refined parameters for all particles.
    """
    # If no device is specified, default to GPU 0 (cuda:0)
    if device is None:
        device = [torch.device("cuda:0")]

    # Convert single device to list for consistent handling
    if isinstance(device, torch.device):
        device = [device]

    ###########################################
    ### Split particle stack across devices ###
    ###########################################
    kwargs_per_device = construct_multi_gpu_refine_template_kwargs(
        particle_stack_dft=particle_stack_dft,
        template_dft=template_dft,
        euler_angles=euler_angles,
        euler_angle_offsets=euler_angle_offsets,
        defocus_u=defocus_u,
        defocus_v=defocus_v,
        defocus_angle=defocus_angle,
        defocus_offsets=defocus_offsets,
        pixel_size_offsets=pixel_size_offsets,
        corr_mean=corr_mean,
        corr_std=corr_std,
        ctf_kwargs=ctf_kwargs,
        projective_filters=projective_filters,
        batch_size=batch_size,
        devices=device,
    )

    results = run_multiprocess_jobs(
        target=_core_refine_template_single_gpu,
        kwargs_list=kwargs_per_device,
    )

    # Shape information for offset calculations
    _, img_h, img_w = particle_stack_dft.shape
    _, template_h, template_w = template_dft.shape
    # account for RFFT
    img_w = 2 * (img_w - 1)
    template_w = 2 * (template_w - 1)

    # Concatenate results from all devices
    refined_cross_correlation = torch.cat(
        [torch.from_numpy(r["refined_cross_correlation"]) for r in results.values()]
    )
    refined_z_score = torch.cat(
        [torch.from_numpy(r["refined_z_score"]) for r in results.values()]
    )
    refined_euler_angles = torch.cat(
        [torch.from_numpy(r["refined_euler_angles"]) for r in results.values()]
    )
    refined_defocus_offset = torch.cat(
        [torch.from_numpy(r["refined_defocus_offset"]) for r in results.values()]
    )
    refined_pixel_size_offset = torch.cat(
        [torch.from_numpy(r["refined_pixel_size_offset"]) for r in results.values()]
    )
    refined_pos_y = torch.cat(
        [torch.from_numpy(r["refined_pos_y"]) for r in results.values()]
    )
    refined_pos_x = torch.cat(
        [torch.from_numpy(r["refined_pos_x"]) for r in results.values()]
    )

    # Ensure the results are sorted back to the original particle order
    # (If particles were split across devices, we need to reorder the results)
    particle_indices = torch.cat(
        [torch.from_numpy(r["particle_indices"]) for r in results.values()]
    )
    angle_idx = torch.cat([torch.from_numpy(r["angle_idx"]) for r in results.values()])
    sort_indices = torch.argsort(particle_indices)

    refined_cross_correlation = refined_cross_correlation[sort_indices]
    refined_z_score = refined_z_score[sort_indices]
    refined_euler_angles = refined_euler_angles[sort_indices]
    refined_defocus_offset = refined_defocus_offset[sort_indices]
    refined_pixel_size_offset = refined_pixel_size_offset[sort_indices]
    refined_pos_y = refined_pos_y[sort_indices]
    refined_pos_x = refined_pos_x[sort_indices]
    angle_idx = angle_idx[sort_indices]
    # Offset refined_pos_{x,y} by the extracted box size (same as original)
    refined_pos_y -= (img_h - template_h + 1) // 2
    refined_pos_x -= (img_w - template_w + 1) // 2

    return {
        "refined_cross_correlation": refined_cross_correlation,
        "refined_z_score": refined_z_score,
        "refined_euler_angles": refined_euler_angles,
        "refined_defocus_offset": refined_defocus_offset,
        "refined_pixel_size_offset": refined_pixel_size_offset,
        "refined_pos_y": refined_pos_y,
        "refined_pos_x": refined_pos_x,
        "angle_idx": angle_idx,
    }
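
Because the per-device results come back in device order, the function sorts them back to the original particle order using the saved particle indices. A minimal sketch of that reordering step (illustrative values only):

```python
# Illustrative sketch of the reordering performed above: results concatenated from
# two devices are sorted back into the original particle order via argsort.
import torch

particle_indices = torch.tensor([3, 4, 5, 0, 1, 2])                  # order after concatenation
refined_cross_correlation = torch.tensor([3., 4., 5., 0., 1., 2.])   # values in device order

sort_indices = torch.argsort(particle_indices)
print(refined_cross_correlation[sort_indices])  # tensor([0., 1., 2., 3., 4., 5.])
```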

cross_correlate_particle_stack(particle_stack_dft, template_dft, rotation_matrices, projective_filters, mode='valid', batch_size=1024)

Cross-correlate a stack of particle images against a template.

Here, the argument 'particle_stack_dft' is a set of RFFT-ed particle images with necessary filtering already applied. The zeroth dimension corresponds to unique particles.
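
As a hypothetical illustration of that convention, a real-space particle stack could be prepared with torch's real FFT (shapes are illustrative; any required filtering is assumed to happen separately):

```python
# Hypothetical preparation of particle_stack_dft: real-space particle images are
# real-Fourier transformed and left un-fftshifted, per the convention described above.
import torch

particles = torch.randn(8, 128, 128)                           # (N, H, W) real-space images
particle_stack_dft = torch.fft.rfftn(particles, dim=(-2, -1))  # (N, H, W // 2 + 1)
print(particle_stack_dft.shape)  # torch.Size([8, 128, 65])
```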

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `particle_stack_dft` | `Tensor` | The stack of particle real-Fourier transformed and un-fftshifted images. Shape of (N, H, W). | required |
| `template_dft` | `Tensor` | The template volume to extract central slices from. Real-Fourier transformed and fftshifted. | required |
| `rotation_matrices` | `Tensor` | The orientations of the particles to take the Fourier slices of, as a long list of rotation matrices. Shape of (N, 3, 3). | required |
| `projective_filters` | `Tensor` | Projective filters to apply to each Fourier slice particle. Shape of (N, h, w). | required |
| `mode` | `Literal['valid', 'same']` | Correlation mode to use, by default "valid". If "valid", the output will be the valid cross-correlation of the inputs. If "same", the output will be the same shape as the input particle stack. | `'valid'` |
| `batch_size` | `int` | The number of particle images to cross-correlate at once. Default is 1024. Larger sizes will consume more memory. If -1, then the entire stack will be cross-correlated at once. | `1024` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The cross-correlation of the particle stack with the template. Shape will depend on the mode used. If "valid", the output will be (N, H-h+1, W-w+1). If "same", the output will be (N, H, W). |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the mode is not "valid" or "same". |

Source code in src/leopard_em/backend/core_refine_template.py
def cross_correlate_particle_stack(
    particle_stack_dft: torch.Tensor,  # (N, H, W)
    template_dft: torch.Tensor,  # (d, h, w)
    rotation_matrices: torch.Tensor,  # (N, 3, 3)
    projective_filters: torch.Tensor,  # (N, h, w)
    mode: Literal["valid", "same"] = "valid",
    batch_size: int = 1024,
) -> torch.Tensor:
    """Cross-correlate a stack of particle images against a template.

    Here, the argument 'particle_stack_dft' is a set of RFFT-ed particle images with
    necessary filtering already applied. The zeroth dimension corresponds to unique
    particles.

    Parameters
    ----------
    particle_stack_dft : torch.Tensor
        The stack of particle real-Fourier transformed and un-fftshifted images.
        Shape of (N, H, W).
    template_dft : torch.Tensor
        The template volume to extract central slices from. Real-Fourier transformed
        and fftshifted.
    rotation_matrices : torch.Tensor
        The orientations of the particles to take the Fourier slices of, as a long
        list of rotation matrices. Shape of (N, 3, 3).
    projective_filters : torch.Tensor
        Projective filters to apply to each Fourier slice particle. Shape of (N, h, w).
    mode : Literal["valid", "same"], optional
        Correlation mode to use, by default "valid". If "valid", the output will be
        the valid cross-correlation of the inputs. If "same", the output will be the
        same shape as the input particle stack.
    batch_size : int, optional
        The number of particle images to cross-correlate at once. Default is 1024.
        Larger sizes will consume more memory. If -1, then the entire stack will be
        cross-correlated at once.

    Returns
    -------
    torch.Tensor
        The cross-correlation of the particle stack with the template. Shape will depend
        on the mode used. If "valid", the output will be (N, H-h+1, W-w+1). If "same",
        the output will be (N, H, W).

    Raises
    ------
    ValueError
        If the mode is not "valid" or "same".
    """
    # Helpful constants for later use
    device = particle_stack_dft.device
    num_particles, image_h, image_w = particle_stack_dft.shape
    _, template_h, template_w = template_dft.shape
    # account for RFFT
    image_w = 2 * (image_w - 1)
    template_w = 2 * (template_w - 1)

    if batch_size == -1:
        batch_size = num_particles

    if mode == "valid":
        output_shape = (
            num_particles,
            image_h - template_h + 1,
            image_w - template_w + 1,
        )
    elif mode == "same":
        output_shape = (num_particles, image_h, image_w)
    else:
        raise ValueError(f"Invalid mode: {mode}. Must be 'valid' or 'same'.")

    out_correlation = torch.zeros(output_shape, device=device)

    # Loop over the particle stack in batches
    for i in range(0, num_particles, batch_size):
        batch_particles_dft = particle_stack_dft[i : i + batch_size]
        batch_rotation_matrices = rotation_matrices[i : i + batch_size]
        batch_projective_filters = projective_filters[i : i + batch_size]

        # Extract the Fourier slice and apply the projective filters
        fourier_slice = extract_central_slices_rfft_3d(
            volume_rfft=template_dft,
            image_shape=(template_h,) * 3,
            rotation_matrices=batch_rotation_matrices,
        )
        fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
        fourier_slice[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
        fourier_slice *= -1  # flip contrast
        fourier_slice *= batch_projective_filters

        # Inverse Fourier transform and normalize the projection
        projections = torch.fft.irfftn(fourier_slice, dim=(-2, -1))
        projections = torch.fft.ifftshift(projections, dim=(-2, -1))
        projections = normalize_template_projection(
            projections, (template_h, template_w), (image_h, image_w)
        )

        # Padded forward FFT and cross-correlate
        projections_dft = torch.fft.rfftn(
            projections, dim=(-2, -1), s=(image_h, image_w)
        )
        projections_dft = batch_particles_dft * projections_dft.conj()
        cross_correlation = torch.fft.irfftn(projections_dft, dim=(-2, -1))

        # Handle the output shape
        cross_correlation = handle_correlation_mode(
            cross_correlation, output_shape, mode
        )

        out_correlation[i : i + batch_size] = cross_correlation

    return out_correlation
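
The batched loop above uses the standard FFT cross-correlation identity: multiply one spectrum by the complex conjugate of the other, then inverse transform. A minimal single-image sketch with made-up shapes (normalization and filtering from the function above are omitted):

```python
# Minimal sketch of FFT-based cross-correlation on a single image/template pair.
import torch

H, W, h, w = 64, 64, 16, 16
image = torch.randn(H, W)
template = torch.randn(h, w)

image_dft = torch.fft.rfftn(image, dim=(-2, -1))
template_dft = torch.fft.rfftn(template, dim=(-2, -1), s=(H, W))  # zero-pad to image size

cross_correlation = torch.fft.irfftn(image_dft * template_dft.conj(), dim=(-2, -1))
valid = cross_correlation[: H - h + 1, : W - w + 1]  # "valid" correlation region
print(valid.shape)  # torch.Size([49, 49])
```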