core_match_template_distributed

Distributed multi-node version of the core match_template implementation.

core_match_template_distributed(world_size, rank, local_rank, device, orientation_batch_size=1, num_cuda_streams=1, backend='streamed', **kwargs)

Distributed multi-node core function for the match template program.
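
The world_size, rank, and local_rank arguments are normally supplied by whatever launches the job. As a minimal sketch, assuming the processes are started with torchrun (which exports WORLD_SIZE, RANK, and LOCAL_RANK to each process), they could be collected as below. The helper name read_launcher_env is hypothetical and not part of Leopard-EM.

```python
import os
from typing import Mapping, Optional


def read_launcher_env(env: Optional[Mapping[str, str]] = None) -> dict:
    """Collect process coordinates from torchrun-style environment variables.

    Hypothetical helper for illustration only; not part of Leopard-EM.
    """
    env = os.environ if env is None else env
    coords = {
        "world_size": int(env["WORLD_SIZE"]),
        "rank": int(env["RANK"]),
        "local_rank": int(env["LOCAL_RANK"]),
    }
    # The global rank must index one of the world_size processes
    if not 0 <= coords["rank"] < coords["world_size"]:
        raise ValueError("rank must lie in [0, world_size)")
    return coords


# Example: second process on the second node of a 2-node, 2-GPU-per-node job
coords = read_launcher_env({"WORLD_SIZE": "4", "RANK": "3", "LOCAL_RANK": "1"})
print(coords)  # {'world_size': 4, 'rank': 3, 'local_rank': 1}
```

These values map onto the first three arguments of the function. A typical choice for device would be torch.device("cuda", local_rank), and the process group should already be initialized, since the function checks distributed initialization on entry.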

Parameters:

world_size : int, required
    Total number of processes in the distributed job.
rank : int, required
    Global rank of this process.
local_rank : int, required
    Local rank of this process on the current node.
device : torch.device, required
    The CUDA device to use for this process. This must be a single device.
orientation_batch_size : int, default 1
    Number of orientations to process in a single batch.
num_cuda_streams : int, default 1
    Number of CUDA streams to use for overlapping data transfers and computation.
backend : str, default 'streamed'
    The backend to use for computation. Must be 'streamed' or 'batched'.
**kwargs : dict[str, Tensor], default {}
    Additional keyword arguments passed to the single-GPU core function. For the zeroth rank this should be a dictionary of Tensor objects with the following fields (all other ranks can pass an empty dictionary):
    - image_dft: Real Fourier transform (RFFT) of the image with large image filters already applied. Has shape (H, W // 2 + 1).
    - template_dft: Real Fourier transform (RFFT) of the template volume to take Fourier slices from. Has shape (l, h, w // 2 + 1), with the last dimension being the half-dimension of the real FFT. NOTE: The original template volume should be a cubic volume, i.e. h == w == l.
    - ctf_filters: Stack of CTF filters at different pixel size (Cs) and defocus values to use in the search. Has shape (num_Cs, num_defocus, h, w // 2 + 1), where num_Cs is the number of pixel sizes searched over and num_defocus is the number of defocus values searched over.
    - whitening_filter_template: Precomputed whitening filter for the template volume. Has shape (h, w // 2 + 1). Gets multiplied with the CTF filters to create a filter stack applied to each orientation projection.
    - euler_angles: Euler angles (in 'ZYZ' convention, in units of degrees) to search over. Has shape (num_orientations, 3).
    - defocus_values: 1D tensor of the defocus values, in units of Angstroms, that correspond to the CTF filters. Has shape (num_defocus,).
    - pixel_values: 1D tensor of the pixel size values, in units of Angstroms, that correspond to the CTF filters. Has shape (num_Cs,).
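
The shape requirements for the **kwargs tensors can be summarized in a few lines of plain Python. This is only a sketch for checking inputs before a run: expected_kwarg_shapes and its argument names are hypothetical, and the shapes are copied from the parameter descriptions rather than from any validation code in the library. The last line mirrors how the function counts total projections (orientations times defocus values times pixel sizes).

```python
def expected_kwarg_shapes(
    image_hw, template_size, num_cs, num_defocus, num_orientations
):
    """Return the documented shape of each tensor in **kwargs.

    Hypothetical helper for illustration only; not part of Leopard-EM.
    """
    big_h, big_w = image_hw
    n = template_size  # cubic template volume: l == h == w
    return {
        "image_dft": (big_h, big_w // 2 + 1),
        "template_dft": (n, n, n // 2 + 1),
        "ctf_filters": (num_cs, num_defocus, n, n // 2 + 1),
        "whitening_filter_template": (n, n // 2 + 1),
        "euler_angles": (num_orientations, 3),
        "defocus_values": (num_defocus,),
        "pixel_values": (num_cs,),
    }


shapes = expected_kwarg_shapes(
    image_hw=(4096, 4096),
    template_size=512,
    num_cs=1,
    num_defocus=13,
    num_orientations=1000,
)

# Total number of projections searched: orientations x defocus x pixel sizes
total_projections = (
    shapes["euler_angles"][0]
    * shapes["defocus_values"][0]
    * shapes["pixel_values"][0]
)
print(shapes["ctf_filters"])  # (1, 13, 512, 257)
print(total_projections)      # 13000
```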
Source code in src/leopard_em/backend/core_match_template_distributed.py
def core_match_template_distributed(
    world_size: int,
    rank: int,
    local_rank: int,
    device: torch.device,
    orientation_batch_size: int = 1,
    num_cuda_streams: int = 1,
    backend: str = "streamed",
    **kwargs: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Distributed multi-node core function for the match template program.

    Parameters
    ----------
    world_size : int
        Total number of processes in the distributed job.
    rank : int
        Global rank of this process.
    local_rank : int
        Local rank of this process on the current node.
    device : torch.device
        The CUDA device to use for this process. This *must* be a single device.
    orientation_batch_size : int, optional
        Number of orientations to process in a single batch, by default 1.
    num_cuda_streams : int, optional
        Number of CUDA streams to use for overlapping data transfers and
        computation, by default 1.
    backend : str, optional
        The backend to use for computation. Defaults to 'streamed'.
        Must be 'streamed' or 'batched'.
    **kwargs : dict[str, torch.Tensor]
        Additional keyword arguments passed to the single-GPU core function. For the
        zeroth rank this should be a dictionary of Tensor objects with the following
        fields (all other ranks can pass an empty dictionary):
        - image_dft:
            Real Fourier transform (RFFT) of the image with large image filters
            already applied. Has shape (H, W // 2 + 1).
        - template_dft:
            Real Fourier transform (RFFT) of the template volume to take Fourier
            slices from. Has shape (l, h, w // 2 + 1) with the last dimension being the
            half-dimension for real-FFT transformation. NOTE: The original template
            volume should be a cubic volume, i.e. h == w == l.
        - ctf_filters:
            Stack of CTF filters at different pixel size (Cs) and defocus values to use
            in the search. Has shape (num_Cs, num_defocus, h, w // 2 + 1) where num_Cs
            is the number of pixel sizes searched over, and num_defocus is the number
            of defocus values searched over.
        - whitening_filter_template:
            Precomputed whitening filter for the template volume. Has shape
            (h, w // 2 + 1). Gets multiplied with the CTF filters to create a filter
            stack applied to each orientation projection.
        - euler_angles:
            Euler angles (in 'ZYZ' convention, in units of degrees) to search over. Has
            shape (num_orientations, 3).
        - defocus_values:
            1D tensor of the defocus values, in units of Angstroms, that correspond
            to the CTF filters. Has shape (num_defocus,).
        - pixel_values:
            1D tensor of the pixel size values, in units of Angstroms, that
            correspond to the CTF filters. Has shape (num_Cs,).
    """
    # Check proper distributed initialization and CUDA device
    _check_distributed_and_device(rank, device)
    _ = local_rank  # Not used below; assignment silences unused-argument linters

    torch.cuda.set_device(device)

    # Extract (only on rank zero) and broadcast tensor data to all ranks
    (
        image_dft,
        template_dft,
        ctf_filters,
        whitening_filter_template,
        euler_angles,
        defocus_values,
        pixel_values,
    ) = _extract_and_broadcast_tensors(device, rank, kwargs)

    ##############################################################
    ### Pre-multiply the whitening filter with the CTF filters ###
    ##############################################################

    projective_filters = ctf_filters * whitening_filter_template[None, None, ...]
    total_projections = (
        euler_angles.shape[0] * defocus_values.shape[0] * pixel_values.shape[0]
    )

    ########################################################
    ### TCP Setup for distributed index queue management ###
    ########################################################

    distributed_queue = _setup_distributed_queue(
        world_size=world_size,
        rank=rank,
        orientation_batch_size=orientation_batch_size,
        total_indices=euler_angles.shape[0],
    )

    ###########################################################
    ### Calling the single GPU core match template function ###
    ###########################################################

    dist.barrier()
    (mip, best_global_index, correlation_sum, correlation_squared_sum) = (
        _core_match_template_single_gpu(
            rank=rank,
            index_queue=distributed_queue,  # type: ignore
            image_dft=image_dft,
            template_dft=template_dft,
            euler_angles=euler_angles,
            projective_filters=projective_filters,
            defocus_values=defocus_values,
            pixel_values=pixel_values,
            orientation_batch_size=orientation_batch_size,
            num_cuda_streams=num_cuda_streams,
            backend=backend,
            device=device,
        )
    )
    dist.barrier()

    # Gather all tensors to rank zero GPU
    (
        gather_mip,
        gather_best_global_index,
        gather_correlation_sum,
        gather_correlation_squared_sum,
    ) = _gather_tensors_to_rank_zero(
        world_size=world_size,
        rank=rank,
        mip=mip,
        best_global_index=best_global_index,
        correlation_sum=correlation_sum,
        correlation_squared_sum=correlation_squared_sum,
    )

    ##################################################
    ### Final aggregation step on the main process ###
    ##################################################

    if rank != 0:
        return {}

    # Continue on the main process only
    assert gather_mip is not None
    assert gather_best_global_index is not None
    assert gather_correlation_sum is not None
    assert gather_correlation_squared_sum is not None

    aggregated_results = aggregate_distributed_results(
        results=[
            {
                "mip": mip,
                "best_global_index": gidx,
                "correlation_sum": corr_sum,
                "correlation_squared_sum": corr_sq_sum,
            }
            for mip, gidx, corr_sum, corr_sq_sum in zip(
                gather_mip,
                gather_best_global_index,
                gather_correlation_sum,
                gather_correlation_squared_sum,
            )
        ]
    )
    mip = aggregated_results["mip"]
    best_global_index = aggregated_results["best_global_index"]
    correlation_sum = aggregated_results["correlation_sum"]
    correlation_squared_sum = aggregated_results["correlation_squared_sum"]

    # Ensuring all tensors are now on the CPU device:
    # fmt: off
    mip                     = mip.cpu()
    best_global_index       = best_global_index.cpu()
    correlation_sum         = correlation_sum.cpu()
    correlation_squared_sum = correlation_squared_sum.cpu()
    pixel_values            = pixel_values.cpu()
    defocus_values          = defocus_values.cpu()
    euler_angles            = euler_angles.cpu()
    # fmt: on

    # Map from global search index to the best defocus & angles
    # pylint: disable=duplicate-code
    best_phi, best_theta, best_psi, best_defocus = decode_global_search_index(
        best_global_index, pixel_values, defocus_values, euler_angles
    )

    mip_scaled = torch.empty_like(mip)
    mip, mip_scaled, correlation_mean, correlation_variance = scale_mip(
        mip=mip,
        mip_scaled=mip_scaled,
        correlation_sum=correlation_sum,
        correlation_squared_sum=correlation_squared_sum,
        total_correlation_positions=total_projections,
    )

    return {
        "mip": mip.cpu(),
        "scaled_mip": mip_scaled.cpu(),
        "best_phi": best_phi.cpu(),
        "best_theta": best_theta.cpu(),
        "best_psi": best_psi.cpu(),
        "best_defocus": best_defocus.cpu(),
        "correlation_mean": correlation_mean.cpu(),
        "correlation_variance": correlation_variance.cpu(),
        "total_projections": total_projections,
        "total_orientations": euler_angles.shape[0],
        "total_defocus": defocus_values.shape[0],
    }
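
The final aggregation on rank zero can be pictured with a small, dependency-free sketch. It assumes, without access to the bodies of aggregate_distributed_results and scale_mip, that per-rank results are merged by taking the per-position maximum (carrying along the index that produced it) and summing the running sums, and that the scaled MIP is a z-score built from the resulting mean and variance. The function names combine_rank_results and mean_and_variance are hypothetical, and plain lists stand in for the tensors.

```python
import math


def combine_rank_results(per_rank):
    """Merge per-rank statistics, position by position (illustrative only)."""
    num_positions = len(per_rank[0]["mip"])
    merged = {
        "mip": [-math.inf] * num_positions,
        "best_global_index": [-1] * num_positions,
        "correlation_sum": [0.0] * num_positions,
        "correlation_squared_sum": [0.0] * num_positions,
    }
    for result in per_rank:
        for i in range(num_positions):
            # Keep the largest correlation value and the index that produced it
            if result["mip"][i] > merged["mip"][i]:
                merged["mip"][i] = result["mip"][i]
                merged["best_global_index"][i] = result["best_global_index"][i]
            # Running sums simply add across ranks
            for key in ("correlation_sum", "correlation_squared_sum"):
                merged[key][i] += result[key][i]
    return merged


def mean_and_variance(total, total_squared, count):
    """Per-position mean and population variance from running sums."""
    mean = [s / count for s in total]
    variance = [sq / count - m * m for sq, m in zip(total_squared, mean)]
    return mean, variance


# Two ranks, two positions, two projections searched per rank (four in total)
rank0 = {"mip": [3.0, 2.0], "best_global_index": [1, 0],
         "correlation_sum": [4.0, 4.0], "correlation_squared_sum": [10.0, 8.0]}
rank1 = {"mip": [2.0, 5.0], "best_global_index": [3, 2],
         "correlation_sum": [2.0, 6.0], "correlation_squared_sum": [4.0, 26.0]}

merged = combine_rank_results([rank0, rank1])
mean, variance = mean_and_variance(
    merged["correlation_sum"], merged["correlation_squared_sum"], count=4
)
# Assumed z-score scaling; the real scale_mip may differ in detail
scaled_mip = [
    (m - mu) / math.sqrt(var) for m, mu, var in zip(merged["mip"], mean, variance)
]
print(merged["mip"], merged["best_global_index"])  # [3.0, 5.0] [1, 2]
print(mean, variance)                              # [1.5, 2.5] [1.25, 2.25]
```

Keeping only the maximum, its index, and two running sums per position means each rank sends a fixed amount of data to rank zero regardless of how many orientations it processed.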