Skip to content

process_results

Functions related to result processing after backend functions.

aggregate_distributed_results(results)

Combine the 2DTM results from multiple devices.

NOTE: This assumes that all tensors have been passed back to the CPU and are in the form of numpy arrays.

Parameters:

Name Type Description Default
results list[dict[str, ndarray]]

List of dictionaries containing the results from each device. Each dictionary contains the following keys: - "mip": Maximum intensity projection of the cross-correlation values. - "best_global_index": Best global search index - "correlation_sum": Sum of cross-correlation values for each pixel. - "correlation_squared_sum": Sum of squared cross-correlation values for each pixel.

required
Source code in src/leopard_em/backend/process_results.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def aggregate_distributed_results(
    results: list[dict[str, torch.Tensor | np.ndarray]],
) -> dict[str, torch.Tensor]:
    """Combine the 2DTM results from multiple devices.

    NOTE: This assumes that all tensors have been passed back to the CPU and are in
    the form of numpy arrays.

    Parameters
    ----------
    results : list[dict[str, np.ndarray]]
        List of dictionaries containing the results from each device. Each dictionary
        contains the following keys:
            - "mip": Maximum intensity projection of the cross-correlation values.
            - "best_global_index": Best global search index
            - "correlation_sum": Sum of cross-correlation values for each pixel.
            - "correlation_squared_sum": Sum of squared cross-correlation values for
              each pixel.
    """
    # Ensure all the tensors are passed back to CPU as numpy arrays
    # Not sure why cannot sync across devices, but this is a workaround
    results = [
        {
            key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
            for key, value in result.items()
        }
        for result in results
    ]

    # Stack results from all devices into a single array. Dim 0 is device index
    mips = np.stack([result["mip"] for result in results], axis=0)
    best_index = np.stack([result["best_global_index"] for result in results], axis=0)

    # Find the maximum MIP across all devices, then decode the best index
    mip_max = mips.max(axis=0)
    mip_argmax = mips.argmax(axis=0)
    best_index = np.take_along_axis(best_index, mip_argmax[None, ...], axis=0)[0]

    # Sum the sums and squared sums of the cross-correlation values
    correlation_sum = np.stack(
        [result["correlation_sum"] for result in results], axis=0
    ).sum(axis=0)
    correlation_squared_sum = np.stack(
        [result["correlation_squared_sum"] for result in results], axis=0
    ).sum(axis=0)

    # Cast back to torch tensors on the CPU
    mip_max = torch.from_numpy(mip_max)
    best_index = torch.from_numpy(best_index)
    correlation_sum = torch.from_numpy(correlation_sum)
    correlation_squared_sum = torch.from_numpy(correlation_squared_sum)

    return {
        "mip": mip_max,
        "best_global_index": best_index,
        "correlation_sum": correlation_sum,
        "correlation_squared_sum": correlation_squared_sum,
    }

correlation_sum_and_squared_sum_to_mean_and_variance(correlation_sum, correlation_squared_sum, total_correlation_positions)

Convert the sum and squared sum of the correlation values to the mean and the standard deviation (square root of the variance).

Parameters:

Name Type Description Default
correlation_sum Tensor

Sum of the correlation values.

required
correlation_squared_sum Tensor

Sum of the squared correlation values.

required
total_correlation_positions int

Total number of cross-correlograms calculated.

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple containing the mean and standard deviation of the correlation values.

Source code in src/leopard_em/backend/process_results.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def correlation_sum_and_squared_sum_to_mean_and_variance(
    correlation_sum: torch.Tensor,
    correlation_squared_sum: torch.Tensor,
    total_correlation_positions: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert the sum and squared sum of the correlation values to mean and std.

    NOTE: Despite the function name, the second returned tensor is the standard
    deviation (square root of the variance), not the variance itself -- the code
    takes ``torch.sqrt`` of the clamped variance. Callers (e.g. ``scale_mip``)
    divide by this value to form a Z-score.

    Parameters
    ----------
    correlation_sum : torch.Tensor
        Sum of the correlation values.
    correlation_squared_sum : torch.Tensor
        Sum of the squared correlation values.
    total_correlation_positions : int
        Total number of cross-correlograms calculated.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Tuple containing the mean and standard deviation of the correlation values.
    """
    correlation_mean = correlation_sum / total_correlation_positions
    # Var[x] = E[x^2] - E[x]^2, clamped at zero to absorb floating-point
    # round-off that could otherwise make the sqrt produce NaN
    correlation_variance = correlation_squared_sum / total_correlation_positions
    correlation_variance -= correlation_mean**2
    correlation_variance = torch.sqrt(torch.clamp(correlation_variance, min=0))
    return correlation_mean, correlation_variance

decode_global_search_index(global_indices, pixel_values, defocus_values, euler_angles)

Decode flattened global indices back into (cs, defocus, orientation).

Source code in src/leopard_em/backend/process_results.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def decode_global_search_index(
    global_indices: torch.Tensor,  # integer tensor
    pixel_values: torch.Tensor,  # (num_cs,)
    defocus_values: torch.Tensor,  # (num_defocus,)
    euler_angles: torch.Tensor,  # (num_orientations, 3)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Decode flattened global indices back into (cs, defocus, orientation)."""
    _ = pixel_values  # Unused, but possible to add in future

    # Flattened layout:
    #   index = cs_idx * (n_defocus * n_orientations)
    #         + defocus_idx * n_orientations
    #         + orientation_idx
    n_defocus = defocus_values.shape[0]
    n_orientations = euler_angles.shape[0]
    cs_stride = n_defocus * n_orientations

    # Drop the (currently unused) cs component, then split the remainder into
    # the defocus and orientation indexes.
    # cs_idx = global_indices // cs_stride
    remainder = global_indices % cs_stride
    defocus_idx = remainder // n_orientations
    orientation_idx = remainder % n_orientations

    selected_angles = euler_angles[orientation_idx]  # (..., 3)
    phi = selected_angles[..., 0]
    theta = selected_angles[..., 1]
    psi = selected_angles[..., 2]
    # cs = pixel_values[cs_idx]

    return phi, theta, psi, defocus_values[defocus_idx]

scale_mip(mip, mip_scaled, correlation_sum, correlation_squared_sum, total_correlation_positions)

Scale the MIP to Z-score map by the mean and variance of the correlation values.

Z-score is accounting for the variation in image intensity and spurious correlations by subtracting the mean and dividing by the standard deviation pixel-wise. Since cross-correlation values are roughly normally distributed for pure noise, Z-score effectively becomes a measure of how unexpected (highly correlated to the reference template) a region is in the image. Note that we are looking at maxima of millions of Gaussian distributions, so Z-score has to be compared with a generalized extreme value distribution (GEV) to determine significance (done elsewhere).

NOTE: The in-place update of the correlation_sum and correlation_squared_sum tensors into the mean and variance is currently disabled (the corresponding copy_ calls are commented out in the source); the mean and standard deviation are returned instead.

Parameters:

Name Type Description Default
mip Tensor

MIP of the correlation values.

required
mip_scaled Tensor

Scaled MIP of the correlation values.

required
correlation_sum Tensor

Sum of the correlation values.

required
correlation_squared_sum Tensor

Sum of the squared correlation values.

required
total_correlation_positions int

Total number of cross-correlograms calculated.

required

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, Tensor]

Tuple containing, in order, the MIP, scaled MIP, correlation mean, and correlation standard deviation.

Source code in src/leopard_em/backend/process_results.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def scale_mip(
    mip: torch.Tensor,
    mip_scaled: torch.Tensor,
    correlation_sum: torch.Tensor,
    correlation_squared_sum: torch.Tensor,
    total_correlation_positions: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Scale the MIP to Z-score map by the mean and variance of the correlation values.

    Z-score is accounting for the variation in image intensity and spurious correlations
    by subtracting the mean and dividing by the standard deviation pixel-wise. Since
    cross-correlation values are roughly normally distributed for pure noise, Z-score
    effectively becomes a measure of how unexpected (highly correlated to the reference
    template) a region is in the image. Note that we are looking at maxima of millions
    of Gaussian distributions, so Z-score has to be compared with a generalized extreme
    value distribution (GEV) to determine significance (done elsewhere).

    NOTE: The in-place update of correlation_sum and correlation_squared_sum into the
    mean and variance is currently disabled (the copy_ calls are commented out below);
    the mean and standard deviation are returned instead.

    NOTE: The second value returned by the helper (named "variance" here) is actually
    the standard deviation -- the helper takes the square root -- so the division
    below forms a proper Z-score.

    Parameters
    ----------
    mip : torch.Tensor
        MIP of the correlation values.
    mip_scaled : torch.Tensor
        Buffer for the scaled MIP. The local name is immediately rebound below, so
        the caller's tensor is NOT written in-place; callers must use the returned
        value.
    correlation_sum : torch.Tensor
        Sum of the correlation values.
    correlation_squared_sum : torch.Tensor
        Sum of the squared correlation values.
    total_correlation_positions : int
        Total number of cross-correlograms calculated.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Tuple containing, in order, the MIP, scaled MIP, correlation mean, and
        correlation standard deviation.
    """
    corr_mean, corr_variance = correlation_sum_and_squared_sum_to_mean_and_variance(
        correlation_sum, correlation_squared_sum, total_correlation_positions
    )

    # Calculate normalized MIP. NOTE: this rebinds the local name `mip_scaled`
    # to a fresh tensor; the tensor passed in by the caller is left untouched,
    # and the subsequent `out=` writes into this fresh tensor.
    mip_scaled = mip - corr_mean
    torch.where(
        corr_variance != 0,  # preventing zero division error, albeit unlikely
        mip_scaled / corr_variance,
        torch.zeros_like(mip_scaled),
        out=mip_scaled,
    )

    # # Update correlation_sum and correlation_squared_sum to mean and variance
    # correlation_sum.copy_(corr_mean)
    # correlation_squared_sum.copy_(corr_variance)

    return mip, mip_scaled, corr_mean, corr_variance