
distributed

Utilities related to distributed computing for the backend functions.

DistributedTCPIndexQueue

Bases: WorkIndexQueue

Distributed work index queue backed by torch.distributed.TCPStore.

Drop-in replacement for MultiprocessWorkIndexQueue but for multi-node setups.

Parameters:

store : TCPStore (required)
    A torch.distributed.TCPStore object for managing shared state. Must be already initialized and reachable by all processes.
total_indices : int (required)
    The total number of indices (work items) to be processed. Each index is considered its own work item, and these items will generally be batched together.
batch_size : int (required)
    The number of indices to be processed in each batch.
num_processes : int (required)
    The total number of processes grabbing work from this queue. Used to track how quickly each process is grabbing work from the queue.
prefetch_size : int (default: 10)
    The number of indices to prefetch for processing. A multiplicative factor for batch_size; for example, if batch_size is 10 and prefetch_size is 3, then up to 30 indices will be prefetched for processing.
counter_key : str (default: 'next_index')
    The key in the TCPStore for the shared next-index counter.
error_key : str (default: 'error_flag')
    The key in the TCPStore for the shared error flag.
process_counts_prefix : str (default: 'process_count_')
    The prefix for keys in the TCPStore for the per-process claimed counts.
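
Below is a minimal sketch of how the queue might be wired up across ranks. The import path, backend, host, and port values are assumptions for illustration; adapt them to your launcher and cluster configuration.

import os
from datetime import timedelta

import torch.distributed as dist

from leopard_em.backend.distributed import DistributedTCPIndexQueue  # assumed import path

# Rank/world-size would normally come from your launcher (e.g. torchrun); the
# addresses and ports below are placeholders.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# initialize_store() calls dist.barrier(), so a process group must already exist.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
)

# Every rank connects to the same TCPStore; rank 0 hosts it.
# Positional args: host_name, port, world_size, is_master, timeout.
store = dist.TCPStore("127.0.0.1", 29501, world_size, rank == 0, timedelta(seconds=300))

# Must be called by all ranks: rank 0 writes the initial keys, the barrier syncs.
DistributedTCPIndexQueue.initialize_store(store, rank=rank, num_processes=world_size)

queue = DistributedTCPIndexQueue(
    store=store,
    total_indices=100_000,
    batch_size=32,
    num_processes=world_size,
)
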
Source code in src/leopard_em/backend/distributed.py
class DistributedTCPIndexQueue(WorkIndexQueue):
    """Distributed work index queue backed by torch.distributed.TCPStore.

    Drop-in replacement for MultiprocessWorkIndexQueue but for multi-node setups.

    Parameters
    ----------
    store : dist.TCPStore
        A torch.distributed.TCPStore object for managing shared state. Must be already
        initialized and reachable by all processes.
    total_indices : int
        The total number of indices (work items) to be processed. Each index is
        considered its own work item, and these items will generally be batched
        together.
    batch_size : int
        The number of indices to be processed in each batch.
    num_processes : int
        The total number of processes grabbing work from this queue. Used as a way
        to track how fast each process is grabbing work from the queue.
    prefetch_size : int
        The number of indices to prefetch for processing. Is a multiplicative factor
        for batch_size. For example, if batch_size is 10 and prefetch_size is 3, then
        up to 30 indices will be prefetched for processing.
    counter_key : str
        The key in the TCPStore for the shared next index counter.
    error_key : str
        The key in the TCPStore for the shared error flag.
    process_counts_prefix : str
        The prefix for keys in the TCPStore for the per-process claimed counts.
    """

    store: dist.TCPStore
    total_indices: int
    batch_size: int
    num_processes: int
    prefetch_size: int
    counter_key: str
    error_key: str
    process_counts_prefix: str

    def __init__(
        self,
        store: dist.TCPStore,
        total_indices: int,
        batch_size: int,
        num_processes: int,
        prefetch_size: int = 10,
        counter_key: str = "next_index",
        error_key: str = "error_flag",
        process_counts_prefix: str = "process_count_",
    ):
        super().__init__(total_indices, batch_size, num_processes, prefetch_size)

        self.store = store
        self.counter_key = counter_key
        self.error_key = error_key
        self.process_counts_prefix = process_counts_prefix

    @staticmethod
    def initialize_store(
        store: dist.TCPStore,
        rank: int,
        num_processes: int,
        counter_key: str = "next_index",
        error_key: str = "error_flag",
        process_counts_prefix: str = "process_count_",
    ) -> None:
        """Have rank 0 initialize the shared keys in the store.

        NOTE: Includes a synchronization barrier so MUST be called by all processes.
        """
        if rank == 0:
            # set keys unconditionally on the server to avoid compare_set races
            store.set(counter_key, "0")
            store.set(error_key, "0")
            for pid in range(num_processes):
                store.set(f"{process_counts_prefix}{pid}", "0")
        # synchronize so other ranks can safely call store.get()/add()
        dist.barrier()

    def get_next_indices(
        self, process_id: Optional[int] = None
    ) -> Optional[tuple[int, int]]:
        """Atomically claim the next chunk of indices for a process."""
        delta = self.batch_size * self.prefetch_size

        # fetch-and-add returns the *new* value after increment
        new_val = self.store.add(self.counter_key, delta)
        end_idx = int(new_val)
        start_idx = end_idx - delta

        if start_idx >= self.total_indices:
            return None

        end_idx = min(end_idx, self.total_indices)

        claimed = end_idx - start_idx
        if process_id is not None and claimed > 0:
            self.store.add(f"{self.process_counts_prefix}{process_id}", claimed)

        if claimed <= 0:
            return None
        return (start_idx, end_idx)

    def get_current_index(self) -> int:
        """Get the current progress of the queue."""
        return int(self.store.get(self.counter_key).decode("utf-8"))

    def get_process_counts(self) -> list[int]:
        """Get per-process claimed counts."""
        counts = []
        for pid in range(self.num_processes):
            v = int(
                self.store.get(f"{self.process_counts_prefix}{pid}").decode("utf-8")
            )
            counts.append(v)
        return counts

    def error_occurred(self) -> bool:
        """Check if an error has occurred."""
        return bool(self.store.get(self.error_key).decode("utf-8") == "1")

    def set_error_flag(self) -> None:
        """Set the error flag."""
        self.store.set(self.error_key, "1")

error_occurred()

Check if an error has occurred.

Source code in src/leopard_em/backend/distributed.py
def error_occurred(self) -> bool:
    """Check if an error has occurred."""
    return bool(self.store.get(self.error_key).decode("utf-8") == "1")

get_current_index()

Get the current progress of the queue.

Source code in src/leopard_em/backend/distributed.py
def get_current_index(self) -> int:
    """Get the current progress of the queue."""
    return int(self.store.get(self.counter_key).decode("utf-8"))

get_next_indices(process_id=None)

Atomically claim the next chunk of indices for a process.
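
A hedged sketch of how a worker might consume the queue; worker_loop, process_single_item, and the error-flag handling are hypothetical and only illustrate the intended pull loop.

from leopard_em.backend.distributed import DistributedTCPIndexQueue  # assumed import path


def process_single_item(idx: int) -> None:
    """Hypothetical per-index work function."""


def worker_loop(queue: DistributedTCPIndexQueue, rank: int) -> None:
    try:
        while (chunk := queue.get_next_indices(process_id=rank)) is not None:
            if queue.error_occurred():
                break  # another rank hit an error; stop pulling work
            start_idx, end_idx = chunk
            for idx in range(start_idx, end_idx):
                process_single_item(idx)
    except Exception:
        queue.set_error_flag()  # tell the other ranks to stop
        raise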

Source code in src/leopard_em/backend/distributed.py
def get_next_indices(
    self, process_id: Optional[int] = None
) -> Optional[tuple[int, int]]:
    """Atomically claim the next chunk of indices for a process."""
    delta = self.batch_size * self.prefetch_size

    # fetch-and-add returns the *new* value after increment
    new_val = self.store.add(self.counter_key, delta)
    end_idx = int(new_val)
    start_idx = end_idx - delta

    if start_idx >= self.total_indices:
        return None

    end_idx = min(end_idx, self.total_indices)

    claimed = end_idx - start_idx
    if process_id is not None and claimed > 0:
        self.store.add(f"{self.process_counts_prefix}{process_id}", claimed)

    if claimed <= 0:
        return None
    return (start_idx, end_idx)

get_process_counts()

Get per-process claimed counts.

Source code in src/leopard_em/backend/distributed.py
def get_process_counts(self) -> list[int]:
    """Get per-process claimed counts."""
    counts = []
    for pid in range(self.num_processes):
        v = int(
            self.store.get(f"{self.process_counts_prefix}{pid}").decode("utf-8")
        )
        counts.append(v)
    return counts

initialize_store(store, rank, num_processes, counter_key='next_index', error_key='error_flag', process_counts_prefix='process_count_') staticmethod

Have rank 0 initialize the shared keys in the store.

NOTE: Includes a synchronization barrier so MUST be called by all processes.

Source code in src/leopard_em/backend/distributed.py
@staticmethod
def initialize_store(
    store: dist.TCPStore,
    rank: int,
    num_processes: int,
    counter_key: str = "next_index",
    error_key: str = "error_flag",
    process_counts_prefix: str = "process_count_",
) -> None:
    """Have rank 0 initialize the shared keys in the store.

    NOTE: Includes a synchronization barrier so MUST be called by all processes.
    """
    if rank == 0:
        # set keys unconditionally on the server to avoid compare_set races
        store.set(counter_key, "0")
        store.set(error_key, "0")
        for pid in range(num_processes):
            store.set(f"{process_counts_prefix}{pid}", "0")
    # synchronize so other ranks can safely call store.get()/add()
    dist.barrier()

set_error_flag()

Set the error flag.

Source code in src/leopard_em/backend/distributed.py
def set_error_flag(self) -> None:
    """Set the error flag."""
    self.store.set(self.error_key, "1")

MultiprocessWorkIndexQueue

Bases: WorkIndexQueue

Single-node (distributed memory) multiprocessing work index queue.

Uses multiprocessing shared-memory primitives to manage shared state within a single machine.

Parameters:

total_indices : int (required)
    The total number of indices (work items) to be processed.
batch_size : int (required)
    The number of indices to be processed in each batch.
num_processes : int (required)
    The total number of processes grabbing work from this queue.
prefetch_size : int (default: 10)
    The number of indices to prefetch for processing.
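
A small worked example with invented numbers showing how chunks are claimed; it assumes the module is importable as leopard_em.backend.distributed.

from leopard_em.backend.distributed import MultiprocessWorkIndexQueue  # assumed import path

queue = MultiprocessWorkIndexQueue(
    total_indices=100, batch_size=10, num_processes=2, prefetch_size=3
)

# Each call atomically claims up to batch_size * prefetch_size = 30 indices.
print(queue.get_next_indices(process_id=0))  # (0, 30)
print(queue.get_next_indices(process_id=1))  # (30, 60)
print(queue.get_next_indices(process_id=0))  # (60, 90)
print(queue.get_next_indices(process_id=1))  # (90, 100)  clamped to total_indices
print(queue.get_next_indices(process_id=0))  # None       all work has been claimed
print(queue.get_process_counts())            # [60, 40]
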
Source code in src/leopard_em/backend/distributed.py
class MultiprocessWorkIndexQueue(WorkIndexQueue):
    """Single-node (distributed memory) multiprocessing work index queue.

    Uses multiprocessing primitives for shared state management within a single machine
    by using shared memory.

    Parameters
    ----------
    total_indices : int
        The total number of indices (work items) to be processed.
    batch_size : int
        The number of indices to be processed in each batch.
    num_processes : int
        The total number of processes grabbing work from this queue.
    prefetch_size : int, optional
        The number of indices to prefetch for processing, by default 10.
    """

    def __init__(
        self,
        total_indices: int,
        batch_size: int,
        num_processes: int,
        prefetch_size: int = 10,
    ):
        super().__init__(total_indices, batch_size, num_processes, prefetch_size)
        self.next_index = mp.Value("i", 0)  # Shared counter
        self.process_counts = mp.Array("i", [0] * num_processes)
        self.error_flag = mp.Value("i", 0)  # 0 = no error, 1 = error occurred
        self.lock = mp.Lock()

    def get_next_indices(
        self, process_id: Optional[int] = None
    ) -> Optional[tuple[int, int]]:
        """Get the next set of indices to process returning None if all work is done.

        Parameters
        ----------
        process_id: Optional[int]
            Optional process index to use for updating the 'process_counts' array.
            Default is None which corresponds to no update.
        """
        with self.lock:
            start_idx = self.next_index.value
            if start_idx >= self.total_indices:
                return None

            # Do not go past total_indices
            end_idx = min(
                start_idx + self.batch_size * self.prefetch_size, self.total_indices
            )
            self.next_index.value = end_idx

            # Update the per-process counter
            if process_id is not None:
                self.process_counts[process_id] += end_idx - start_idx

            return (start_idx, end_idx)

    def get_current_index(self) -> int:
        """Get the current progress of the work queue (as an integer)."""
        with self.lock:
            return int(self.next_index.value)

    def get_process_counts(self) -> list[int]:
        """Get the number of indexes of work processed by each process."""
        with self.lock:
            return list(self.process_counts)

    def error_occurred(self) -> bool:
        """Check if an error has occurred in any process."""
        with self.lock:
            return bool(self.error_flag.value == 1)

    def set_error_flag(self) -> None:
        """Set the error flag to indicate an error has occurred."""
        with self.lock:
            self.error_flag.value = 1

error_occurred()

Check if an error has occurred in any process.

Source code in src/leopard_em/backend/distributed.py
def error_occurred(self) -> bool:
    """Check if an error has occurred in any process."""
    with self.lock:
        return bool(self.error_flag.value == 1)

get_current_index()

Get the current progress of the work queue (as an integer).

Source code in src/leopard_em/backend/distributed.py
def get_current_index(self) -> int:
    """Get the current progress of the work queue (as an integer)."""
    with self.lock:
        return int(self.next_index.value)

get_next_indices(process_id=None)

Get the next set of indices to process returning None if all work is done.

Parameters:

process_id : Optional[int] (default: None)
    Optional process index to use for updating the 'process_counts' array. Default is None, which corresponds to no update.
Source code in src/leopard_em/backend/distributed.py
def get_next_indices(
    self, process_id: Optional[int] = None
) -> Optional[tuple[int, int]]:
    """Get the next set of indices to process returning None if all work is done.

    Parameters
    ----------
    process_id: Optional[int]
        Optional process index to use for updating the 'process_counts' array.
        Default is None which corresponds to no update.
    """
    with self.lock:
        start_idx = self.next_index.value
        if start_idx >= self.total_indices:
            return None

        # Do not go past total_indices
        end_idx = min(
            start_idx + self.batch_size * self.prefetch_size, self.total_indices
        )
        self.next_index.value = end_idx

        # Update the per-process counter
        if process_id is not None:
            self.process_counts[process_id] += end_idx - start_idx

        return (start_idx, end_idx)

get_process_counts()

Get the number of indexes of work processed by each process.

Source code in src/leopard_em/backend/distributed.py
def get_process_counts(self) -> list[int]:
    """Get the number of indexes of work processed by each process."""
    with self.lock:
        return list(self.process_counts)

set_error_flag()

Set the error flag to indicate an error has occurred.

Source code in src/leopard_em/backend/distributed.py
def set_error_flag(self) -> None:
    """Set the error flag to indicate an error has occurred."""
    with self.lock:
        self.error_flag.value = 1

TensorShapeDataclass dataclass

Helper class for sending expected tensor shapes to distributed processes.
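
For illustration only: the shapes below are invented placeholders; in practice they are derived from the actual image, template, and search tensors.

shapes = TensorShapeDataclass(
    image_dft_shape=(4096, 2049),                # (H, W // 2 + 1)
    template_dft_shape=(512, 512, 257),          # (l, h, w // 2 + 1)
    ctf_filters_shape=(1, 9, 512, 257),          # (num_Cs, num_defocus, h, w // 2 + 1)
    whitening_filter_template_shape=(512, 257),  # (h, w // 2 + 1)
    euler_angles_shape=(1_500_000, 3),           # (num_orientations, 3)
    defocus_values_shape=(9,),                   # (num_defocus,)
    pixel_values_shape=(1,),                     # (num_Cs,)
)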

Source code in src/leopard_em/backend/distributed.py
@dataclass
class TensorShapeDataclass:
    """Helper class for sending expected tensor shapes to distributed processes."""

    image_dft_shape: tuple[int, int]  # (H, W // 2 + 1)
    template_dft_shape: tuple[int, int, int]  # (l, h, w // 2 + 1)
    ctf_filters_shape: tuple[int, int, int, int]  # (num_Cs, num_defocus, h, w // 2 + 1)
    whitening_filter_template_shape: tuple[int, int]  # (h, w // 2 + 1)
    euler_angles_shape: tuple[int, int]  # (num_orientations, 3)
    defocus_values_shape: tuple[int]  # (num_defocus,)
    pixel_values_shape: tuple[int]  # (num_Cs,)

WorkIndexQueue

Bases: ABC

Abstract base class for index queues that manage distributed work allocation.

This class defines the common interface for both single-node multiprocessing and multi-node distributed computing scenarios.

Parameters:

total_indices : int (required)
    The total number of indices (work items) to be processed.
batch_size : int (required)
    The number of indices to be processed in each batch.
num_processes : int (required)
    The total number of processes grabbing work from this queue.
prefetch_size : int (default: 10)
    The number of indices to prefetch for processing (multiplicative factor for batch_size).
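
The sketch below is not part of the library; it is a hypothetical single-process implementation, shown only to illustrate the interface a subclass must provide. The import path is an assumption.

from typing import Optional

from leopard_em.backend.distributed import WorkIndexQueue  # assumed import path


class LocalWorkIndexQueue(WorkIndexQueue):
    """Hypothetical single-process queue, shown only to illustrate the interface."""

    def __init__(self, total_indices: int, batch_size: int, prefetch_size: int = 10):
        super().__init__(total_indices, batch_size, num_processes=1, prefetch_size=prefetch_size)
        self._next = 0
        self._counts = [0]
        self._error = False

    def get_next_indices(
        self, process_id: Optional[int] = None
    ) -> Optional[tuple[int, int]]:
        if self._next >= self.total_indices:
            return None
        start = self._next
        self._next = min(start + self.batch_size * self.prefetch_size, self.total_indices)
        if process_id is not None:
            self._counts[process_id] += self._next - start
        return (start, self._next)

    def get_current_index(self) -> int:
        return self._next

    def get_process_counts(self) -> list[int]:
        return list(self._counts)

    def error_occurred(self) -> bool:
        return self._error

    def set_error_flag(self) -> None:
        self._error = True
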
Source code in src/leopard_em/backend/distributed.py
class WorkIndexQueue(ABC):
    """Abstract base class for index queues that manage distributed work allocation.

    This class defines the common interface for both single-node multiprocessing
    and multi-node distributed computing scenarios.

    Parameters
    ----------
    total_indices : int
        The total number of indices (work items) to be processed.
    batch_size : int
        The number of indices to be processed in each batch.
    num_processes : int
        The total number of processes grabbing work from this queue.
    prefetch_size : int
        The number of indices to prefetch for processing (multiplicative factor
        for batch_size).
    """

    def __init__(
        self,
        total_indices: int,
        batch_size: int,
        num_processes: int,
        prefetch_size: int = 10,
    ):
        self.total_indices = total_indices
        self.batch_size = batch_size
        self.num_processes = num_processes
        self.prefetch_size = prefetch_size

    @abstractmethod
    def get_next_indices(
        self, process_id: Optional[int] = None
    ) -> Optional[tuple[int, int]]:
        """Get the next set of indices to process, returning None if all work is done.

        Parameters
        ----------
        process_id : Optional[int]
            Optional process index for updating per-process counters.

        Returns
        -------
        Optional[tuple[int, int]]
            Tuple of (start_idx, end_idx) or None if no work remains.
        """
        raise NotImplementedError

    @abstractmethod
    def get_current_index(self) -> int:
        """Get the current progress of the work queue."""
        raise NotImplementedError

    @abstractmethod
    def get_process_counts(self) -> list[int]:
        """Get per-process work counts."""
        raise NotImplementedError

    @abstractmethod
    def error_occurred(self) -> bool:
        """Check if an error has occurred in any process."""
        raise NotImplementedError

    @abstractmethod
    def set_error_flag(self) -> None:
        """Set the error flag to indicate an error has occurred."""
        raise NotImplementedError

error_occurred() abstractmethod

Check if an error has occurred in any process.

Source code in src/leopard_em/backend/distributed.py
@abstractmethod
def error_occurred(self) -> bool:
    """Check if an error has occurred in any process."""
    raise NotImplementedError

get_current_index() abstractmethod

Get the current progress of the work queue.

Source code in src/leopard_em/backend/distributed.py
@abstractmethod
def get_current_index(self) -> int:
    """Get the current progress of the work queue."""
    raise NotImplementedError

get_next_indices(process_id=None) abstractmethod

Get the next set of indices to process, returning None if all work is done.

Parameters:

process_id : Optional[int] (default: None)
    Optional process index for updating per-process counters.

Returns:

Optional[tuple[int, int]]
    Tuple of (start_idx, end_idx) or None if no work remains.

Source code in src/leopard_em/backend/distributed.py
@abstractmethod
def get_next_indices(
    self, process_id: Optional[int] = None
) -> Optional[tuple[int, int]]:
    """Get the next set of indices to process, returning None if all work is done.

    Parameters
    ----------
    process_id : Optional[int]
        Optional process index for updating per-process counters.

    Returns
    -------
    Optional[tuple[int, int]]
        Tuple of (start_idx, end_idx) or None if no work remains.
    """
    raise NotImplementedError

get_process_counts() abstractmethod

Get per-process work counts.

Source code in src/leopard_em/backend/distributed.py
@abstractmethod
def get_process_counts(self) -> list[int]:
    """Get per-process work counts."""
    raise NotImplementedError

set_error_flag() abstractmethod

Set the error flag to indicate an error has occurred.

Source code in src/leopard_em/backend/distributed.py
@abstractmethod
def set_error_flag(self) -> None:
    """Set the error flag to indicate an error has occurred."""
    raise NotImplementedError

run_multiprocess_jobs(target, kwargs_list, extra_args=(), extra_kwargs=None, post_start_callback=None, ranks=None)

Helper function for running multiple processes on the same target function.

Spawns multiple processes to run the same target function with different keyword arguments, aggregates results in a shared dictionary, and returns them.

Parameters:

target : Callable (required)
    The function that each process will execute. It must accept at least two positional arguments: a shared dict and a unique index.
kwargs_list : list[dict[str, Any]] (required)
    A list of dictionaries containing keyword arguments for each process.
extra_args : tuple[Any, ...] (default: ())
    Additional positional arguments to pass to the target, placed before the shared dict and index.
extra_kwargs : Optional[dict[str, Any]] (default: None)
    Additional keyword arguments common to all processes.
post_start_callback : Optional[Callable] (default: None)
    Callback function to call after all processes have been started.
ranks : Optional[list[int]] (default: None)
    If not None, pass these integers as the ranks to the processes; otherwise the index into kwargs_list is used (default). Must be the same length as kwargs_list.

Returns:

dict[Any, Any]
    Aggregated results stored in the shared dictionary.

Raises:

RuntimeError
    If any child process encounters an error.
ValueError
    If ranks is not None and its length does not match kwargs_list.

Example
def worker_fn(result_dict, idx, param1, param2):
    result_dict[idx] = param1 + param2


kwargs_per_process = [
    {"param1": 1, "param2": 2},
    {"param1": 3, "param2": 4},
]
results = run_multiprocess_jobs(worker_fn, kwargs_per_process)
print(results)
# {0: 3, 1: 7}
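
The sketch below (hypothetical worker_fn, assumed import path) illustrates how extra_args, extra_kwargs, and ranks line up with the target's call signature, i.e. target(*extra_args, result_dict, rank, **{**extra_kwargs, **per_process_kwargs}).

from leopard_em.backend.distributed import run_multiprocess_jobs  # assumed import path


def worker_fn(tag, result_dict, rank, scale=1):
    result_dict[rank] = f"{tag}-{rank * scale}"


if __name__ == "__main__":  # recommended when the spawn start method is in use
    results = run_multiprocess_jobs(
        worker_fn,
        kwargs_list=[{}, {"scale": 10}],
        extra_args=("job",),        # placed before the shared dict and rank
        extra_kwargs={"scale": 2},  # overridden by per-process kwargs for rank 7
        ranks=[3, 7],               # explicit ranks instead of 0..len(kwargs_list)-1
    )
    print(results)  # {3: 'job-6', 7: 'job-70'}
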
Source code in src/leopard_em/backend/distributed.py
def run_multiprocess_jobs(
    target: Callable,
    kwargs_list: list[dict[str, Any]],
    extra_args: tuple[Any, ...] = (),
    extra_kwargs: Optional[dict[str, Any]] = None,
    post_start_callback: Optional[Callable] = None,
    ranks: Optional[list[int]] = None,
) -> dict[Any, Any]:
    """Helper function for running multiple processes on the same target function.

    Spawns multiple processes to run the same target function with different keyword
    arguments, aggregates results in a shared dictionary, and returns them.

    Parameters
    ----------
    target : Callable
        The function that each process will execute. It must accept at least two
        positional arguments: a shared dict and a unique index.
    kwargs_list : list[dict[str, Any]]
        A list of dictionaries containing keyword arguments for each process.
    extra_args : tuple[Any, ...], optional
        Additional positional arguments to pass to the target (prepending the shared
        parameters).
    extra_kwargs : Optional[dict[str, Any]], optional
        Additional common keyword arguments for all processes.
    post_start_callback : Optional[Callable], optional
        Callback function to call after all processes have been started.
    ranks : list[int], optional
        If not None, then pass these integers as the ranks to the processes. Otherwise,
        pass the index of the kwargs_list (default). Must be the same length as
        kwargs_list.

    Returns
    -------
    dict[Any, Any]
        Aggregated results stored in the shared dictionary.

    Raises
    ------
    RuntimeError
        If any child process encounters an error.
    ValueError
        If ranks is not None and its length does not match kwargs_list.

    Example
    -------
    ```
    def worker_fn(result_dict, idx, param1, param2):
        result_dict[idx] = param1 + param2


    kwargs_per_process = [
        {"param1": 1, "param2": 2},
        {"param1": 3, "param2": 4},
    ]
    results = run_multiprocess_jobs(worker_fn, kwargs_per_process)
    print(results)
    # {0: 3, 1: 7}
    ```
    """
    if ranks is not None:
        if len(kwargs_list) != len(ranks):
            raise ValueError("Length of ranks must match length of kwargs_list.")
    else:
        ranks = list(range(len(kwargs_list)))

    if extra_kwargs is None:
        extra_kwargs = {}

    # Manager object for shared result data as a dictionary
    manager = Manager()
    result_dict = manager.dict()
    processes: list[Process] = []

    for rank, kwargs in zip(ranks, kwargs_list):
        args = (*extra_args, result_dict, rank)

        # Merge per-process kwargs with common kwargs.
        proc_kwargs = {**extra_kwargs, **kwargs}
        p = Process(target=target, args=args, kwargs=proc_kwargs)
        processes.append(p)
        p.start()

    if post_start_callback is not None:
        post_start_callback()

    for p in processes:
        p.join()

    return dict(result_dict)