utils

Utility and helper functions associated with the backend of Leopard-EM.

attempt_torch_compilation(target_func, backend='inductor', mode='default')

Compile a function using Torch's compilation utilities.

NOTE: This function will fall back to the original function if compilation fails or is not supported. Under these circumstances, a warning is issued to inform the user of the failure, but the program will continue to run with the original function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `target_func` | `Callable` | The function to compile. | *required* |
| `backend` | `str` | The backend to use for compilation. | `'inductor'` |
| `mode` | `str` | The mode for compilation. | `'default'` |

Returns:

| Type | Description |
| --- | --- |
| `Callable` | The potentially compiled function. |

Warning

If compilation fails, the original function is returned without modification, which preserves program consistency. If compilation is not supported, a warning is generated and the original function is returned.
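
A minimal usage sketch (the `scale_and_sum` function below is hypothetical and used only for illustration):

```python
import torch

from leopard_em.backend.utils import attempt_torch_compilation

def scale_and_sum(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example function, not part of Leopard-EM
    return (x * 2.0).sum()

# Returns a compiled function, or the original one (with a warning) on failure
fast_fn = attempt_torch_compilation(scale_and_sum, backend="inductor", mode="default")
print(fast_fn(torch.arange(4, dtype=torch.float32)))  # tensor(12.)

# Compilation can also be disabled globally via the environment variable the
# function checks, e.g. os.environ["LEOPARDEM_DISABLE_TORCH_COMPILATION"] = "1"
```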

Source code in `src/leopard_em/backend/utils.py`:

```python
def attempt_torch_compilation(
    target_func: F, backend: str = "inductor", mode: str = "default"
) -> F:
    """Compile a function using Torch's compilation utilities.

    NOTE: This function will fall back to the original function if compilation
    fails or is not supported. Under these circumstances, a warning is issued to
    inform the user of the failure, but the program will continue to run with the
    original function.

    Parameters
    ----------
    target_func : Callable
        The function to compile.
    backend : str, optional
        The backend to use for compilation (default is "inductor").
    mode : str, optional
        The mode for compilation (default is "default").

    Returns
    -------
    Callable
        The potentially compiled function.

    Warning
    -------
    If compilation fails, the original function is returned without modification,
    which preserves program consistency. If compilation is not supported, then a
    warning is generated, and the original function is returned.
    """
    # Check if compilation is disabled via environment variable
    disable_compilation = os.environ.get("LEOPARDEM_DISABLE_TORCH_COMPILATION", "0")
    if disable_compilation != "0":
        return target_func

    try:
        compiled_func = torch.compile(target_func, backend=backend, mode=mode)
        return compiled_func  # type: ignore[no-any-return]
    except (RuntimeError, NotImplementedError) as e:
        warnings.warn(
            f"Failed to compile function {target_func.__name__} with"
            f"backend {backend}: {e}. "
            "Returning the original function instead and continuing...",
            UserWarning,
            stacklevel=2,
        )
        return target_func
```

do_iteration_statistics_updates(cross_correlation, current_indexes, mip, best_global_index, correlation_sum, correlation_squared_sum, img_h, img_w)

Helper function for updating maxima and tracked statistics.

NOTE: The batch dimensions are effectively unraveled since taking the maximum over a single batch dimension is much faster than a multi-dimensional maximum.

NOTE: Updating the maxima was found to be fastest and least memory-intensive when using torch.where directly. Other methods tested were boolean masking and torch.where with tuples of tensor indexes.
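
For reference, a standalone sketch of the unravel-then-update pattern these notes describe (toy shapes):

```python
import torch

H, W = 4, 4
cc = torch.randn(2, 3, 5, H, W)  # (num_cs, num_defocus, num_orientations, H, W)
mip = torch.zeros(H, W)          # running maximum intensity projection

# Unravel the batch dimensions so the maximum runs over a single dimension
cc_flat = cc.view(-1, H, W)
max_values, max_indices = torch.max(cc_flat, dim=0)

# Masked in-place update with torch.where directly
update_mask = max_values > mip
torch.where(update_mask, max_values, mip, out=mip)
```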

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cross_correlation` | `Tensor` | Cross-correlation values for the current iteration. Has shape `(num_cs, num_defocus, num_orientations, H, W)`, where `num_cs` is the number of different pixel sizes (controlled by spherical aberration Cs), `num_defocus` is the number of different defocus values, and `num_orientations` is the number of different orientations in the cross-correlation batch. | *required* |
| `current_indexes` | `Tensor` | The global search indexes for the current batch of pixel sizes, defocus values, and orientations. Has shape `num_cs * num_defocus * num_orientations` to uniquely identify the set of pixel sizes, defocus values, and orientations associated with the batch within the global search space. | *required* |
| `mip` | `Tensor` | Maximum intensity projection of the cross-correlation values. | *required* |
| `best_global_index` | `Tensor` | Previous best global search indexes. Has shape `(H, W)` and is of int32 type. | *required* |
| `correlation_sum` | `Tensor` | Sum of cross-correlation values for each pixel. | *required* |
| `correlation_squared_sum` | `Tensor` | Sum of squared cross-correlation values for each pixel. | *required* |
| `img_h` | `int` | Height of the cross-correlation values. | *required* |
| `img_w` | `int` | Width of the cross-correlation values. | *required* |
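
A minimal end-to-end sketch of how the running statistics might be initialized and updated (toy shapes and random data; in practice the shapes come from the search configuration):

```python
import torch

from leopard_em.backend.utils import do_iteration_statistics_updates

H, W = 8, 8
num_cs, num_defocus, num_orientations = 1, 2, 3
batch_size = num_cs * num_defocus * num_orientations

# Running statistics, initialized once before the search loop
mip = torch.zeros(H, W)
best_global_index = torch.zeros(H, W, dtype=torch.int32)
correlation_sum = torch.zeros(H, W)
correlation_squared_sum = torch.zeros(H, W)

# One (hypothetical) batch of cross-correlations and its global search indexes
cross_correlation = torch.randn(num_cs, num_defocus, num_orientations, H, W)
current_indexes = torch.arange(batch_size, dtype=torch.int32)

# Updates mip, best_global_index, and the correlation sums in place
do_iteration_statistics_updates(
    cross_correlation,
    current_indexes,
    mip,
    best_global_index,
    correlation_sum,
    correlation_squared_sum,
    H,
    W,
)
```
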
Source code in `src/leopard_em/backend/utils.py`:

```python
def do_iteration_statistics_updates(
    cross_correlation: torch.Tensor,
    current_indexes: torch.Tensor,
    mip: torch.Tensor,
    best_global_index: torch.Tensor,
    correlation_sum: torch.Tensor,
    correlation_squared_sum: torch.Tensor,
    img_h: int,
    img_w: int,
) -> None:
    """Helper function for updating maxima and tracked statistics.

    NOTE: The batch dimensions are effectively unraveled since taking the
    maximum over a single batch dimension is much faster than
    multi-dimensional maxima.

    NOTE: Updating the maxima was found to be fastest and least memory
    intensive when using torch.where directly. Other methods tested were
    boolean masking and torch.where with tuples of tensor indexes.

    Parameters
    ----------
    cross_correlation : torch.Tensor
        Cross-correlation values for the current iteration. Has shape
        (num_cs, num_defocus, num_orientations, H, W) where 'num_cs' are the number of
        different pixel sizes (controlled by spherical aberration Cs) in the
        cross-correlation batch, 'num_defocus' are the number of different defocus
        values in the cross-correlation batch, and 'num_orientations' are the number of
        different orientations in the cross-correlation batch.
    current_indexes : torch.Tensor
        The global search indexes for the *current* batch of pixel sizes, defocus
        values, and orientations. Has shape `num_cs * num_defocus * num_orientations`
        to uniquely identify the set of pixel sizes, defocus values, and orientations
        associated with the batch from the global search space.
    mip : torch.Tensor
        Maximum intensity projection of the cross-correlation values.
    best_global_index : torch.Tensor
        Previous best global search indexes. Has shape (H, W) and is int32 type.
    correlation_sum : torch.Tensor
        Sum of cross-correlation values for each pixel.
    correlation_squared_sum : torch.Tensor
        Sum of squared cross-correlation values for each pixel.
    img_h : int
        Height of the cross-correlation values.
    img_w : int
        Width of the cross-correlation values.
    """
    cc_reshaped = cross_correlation.view(-1, img_h, img_w)

    # Need two passes for maxima operator for memory efficiency
    # and to distinguish between batch position which would both update
    max_values, max_indices = torch.max(cc_reshaped, dim=0)

    # Do masked updates with torch.where directly (in-place)
    update_mask = max_values > mip
    torch.where(update_mask, max_values, mip, out=mip)
    torch.where(
        update_mask,
        current_indexes[max_indices],
        best_global_index,
        out=best_global_index,
    )

    correlation_sum += cc_reshaped.sum(dim=0)
    correlation_squared_sum += (cc_reshaped**2).sum(dim=0)
```

normalize_template_projection(projections, small_shape, large_shape)

Subtract mean of edge values and set variance to 1 (in large shape).

This function uses the fact that the variance of a sequence, Var(X), is scaled by the relative size of the small (unpadded) and large (zero-padded) spaces. Some negligible error (~1e-4) is introduced into the variance by this routine.

Let $X$ be the large, zero-padded projection and $x$ the small projection, with sizes $(H, W)$ and $(h, w)$, respectively. The mean of the zero-padded projection in terms of the small projection is:

$$
\begin{aligned}
    \mu(X) &= \frac{1}{H \cdot W} \sum_{i=1}^{H} \sum_{j=1}^{W} X_{ij} \\
           &= \frac{1}{H \cdot W} \sum_{i=1}^{h} \sum_{j=1}^{w} X_{ij} + 0 \\
           &= \frac{h \cdot w}{H \cdot W} \mu(x)
\end{aligned}
$$

The variance of the zero-padded projection in terms of the small projection is:

$$
\begin{aligned}
    \mathrm{Var}(X) &= \frac{1}{H \cdot W} \sum_{i=1}^{H} \sum_{j=1}^{W} \left(X_{ij} - \mu(X)\right)^2 \\
    &= \frac{1}{H \cdot W} \left( \sum_{i=1}^{h} \sum_{j=1}^{w} \left(X_{ij} - \mu(X)\right)^2 + \sum_{i=h+1}^{H} \sum_{j=w+1}^{W} \mu(X)^2 \right) \\
    &= \frac{1}{H \cdot W} \left( \sum_{i=1}^{h} \sum_{j=1}^{w} \left(X_{ij} - \mu(X)\right)^2 + (H-h)(W-w)\,\mu(X)^2 \right)
\end{aligned}
$$
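
A quick numeric check of the mean identity above (standalone sketch):

```python
import torch

h, w, H, W = 4, 4, 8, 8
x = torch.randn(h, w)

# Zero-pad the small projection into the large shape
X = torch.zeros(H, W)
X[:h, :w] = x

# mu(X) == (h * w) / (H * W) * mu(x)
assert torch.isclose(X.mean(), (h * w) / (H * W) * x.mean())
```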

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `projections` | `Tensor` | Real-space projections of the template (in small space). | *required* |
| `small_shape` | `tuple[int, int]` | Shape of the template. | *required* |
| `large_shape` | `tuple[int, int]` | Shape of the image (in large space). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Edge-mean subtracted projections, still in small space, but normalized so the variance of the zero-padded projection is 1. |
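
A short sanity-check sketch, zero-padding the normalized output and confirming that the padded variance is approximately 1 (a small error is expected, as noted above):

```python
import torch

from leopard_em.backend.utils import normalize_template_projection

h, w, H, W = 16, 16, 64, 64
projections = torch.randn(1, h, w)

normed = normalize_template_projection(projections, (h, w), (H, W))

# Zero-pad into the large shape and check the variance
padded = torch.zeros(1, H, W)
padded[:, :h, :w] = normed
print(padded.var(unbiased=False).item())  # approximately 1.0
```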

Source code in `src/leopard_em/backend/utils.py`:

```python
def normalize_template_projection(
    projections: torch.Tensor,  # shape (batch, h, w)
    small_shape: tuple[int, int],  # (h, w)
    large_shape: tuple[int, int],  # (H, W)
) -> torch.Tensor:
    r"""Subtract mean of edge values and set variance to 1 (in large shape).

    This function uses the fact that variance of a sequence, Var(X), is scaled by the
    relative size of the small (unpadded) and large (padded with zeros) space. Some
    negligible error is introduced into the variance (~1e-4) due to this routine.

    Let $X$ be the large, zero-padded projection and $x$ the small projection each
    with sizes $(H, W)$ and $(h, w)$, respectively. The mean of the zero-padded
    projection in terms of the small projection is:
    .. math::
        \begin{align}
            \mu(X) &= \frac{1}{H \cdot W} \sum_{i=1}^{H} \sum_{j=1}^{W} X_{ij} \\
            \mu(X) &= \frac{1}{H \cdot W} \sum_{i=1}^{h} \sum_{j=1}^{w} X_{ij} + 0 \\
            \mu(X) &= \frac{h \cdot w}{H \cdot W} \mu(x)
        \end{align}
    The variance of the zero-padded projection in terms of the small projection can be
    obtained by:
    .. math::
        \begin{align}
            Var(X) &= \frac{1}{H \cdot W} \sum_{i=1}^{H} \sum_{j=1}^{W} (X_{ij} -
                \mu(X))^2 \\
            Var(X) &= \frac{1}{H \cdot W} \left(\sum_{i=1}^{h}
                \sum_{j=1}^{w} (X_{ij} - \mu(X))^2 +
                \sum_{i=h+1}^{H}\sum_{j=w+1}^{W} \mu(X)^2 \right) \\
            Var(X) &= \frac{1}{H \cdot W} \left( \sum_{i=1}^{h} \sum_{j=1}^{w} (X_{ij} -
                \mu(X))^2 + (H-h)(W-w)\mu(X)^2 \right)
        \end{align}

    Parameters
    ----------
    projections : torch.Tensor
        Real-space projections of the template (in small space).
    small_shape : tuple[int, int]
        Shape of the template.
    large_shape : tuple[int, int]
        Shape of the image (in large space).

    Returns
    -------
    torch.Tensor
        Edge-mean subtracted projections, still in small space, but normalized
        so variance of zero-padded projection is 1.
    """
    # Extract edges while preserving batch dimensions
    top_edge = projections[..., 0, :]  # shape: (..., w)
    bottom_edge = projections[..., -1, :]  # shape: (..., w)
    left_edge = projections[..., 1:-1, 0]  # shape: (..., h-2)
    right_edge = projections[..., 1:-1, -1]  # shape: (..., h-2)
    edge_pixels = torch.concatenate(
        [top_edge, bottom_edge, left_edge, right_edge], dim=-1
    )

    # Subtract the edge pixel mean and calculate variance of small, unpadded projection
    projections = projections - edge_pixels.mean(dim=-1)[..., None, None]

    # # Calculate variance like cisTEM (does not match desired results...)
    # variance = (projections**2).sum(dim=(-1, -2), keepdim=True) * relative_size - (
    #     projections.mean(dim=(-1, -2), keepdim=True) * relative_size
    # ) ** 2

    # Fast calculation of mean/var using Torch + appropriate scaling.
    large_size_sqrt = (large_shape[0] * large_shape[1]) ** 0.5
    relative_size = (small_shape[0] * small_shape[1]) / (
        large_shape[0] * large_shape[1]
    )

    # Mean of the zero-padded projection, mu(X) = relative_size * mu(x)
    mean = torch.mean(projections, dim=(-2, -1), keepdim=True) * relative_size

    # First term of the variance calculation
    variance = torch.sum((projections - mean) ** 2, dim=(-2, -1), keepdim=True)
    # Add the second term of the variance calculation
    variance = variance + (
        (large_shape[0] - small_shape[0]) * (large_shape[1] - small_shape[1]) * mean**2
    )

    projections = (projections * large_size_sqrt) / torch.sqrt(variance.clamp_min(1e-8))
    return projections
```