
particle_stack

Particle stack Pydantic model for dealing with extracted particle data.

ParticleStack

Bases: BaseModel2DTM

Pydantic model for dealing with particle stack data.

Attributes:

Name Type Description
df_path str

Path to the DataFrame containing the particle data. The DataFrame must have the following columns (see the documentation for further information):

  • mip
  • scaled_mip
  • correlation_mean
  • correlation_variance
  • total_correlations
  • pos_x
  • pos_y
  • pos_x_img
  • pos_y_img
  • pos_x_img_angstrom
  • pos_y_img_angstrom
  • psi
  • theta
  • phi
  • relative_defocus
  • refined_relative_defocus
  • defocus_u
  • defocus_v
  • astigmatism_angle
  • pixel_size
  • refined_pixel_size
  • voltage
  • spherical_aberration
  • amplitude_contrast_ratio
  • phase_shift
  • ctf_B_factor
  • micrograph_path
  • template_path
  • mip_path
  • scaled_mip_path
  • psi_path
  • theta_path
  • phi_path
  • defocus_path
  • correlation_average_path
  • correlation_variance_path
extracted_box_size tuple[int, int]

The size (height, width) of the extracted particle boxes, in pixels.

original_template_size tuple[int, int]

The original size of the template used during the matching process. Should be smaller than the extracted box size.

image_stack ExcludedTensor

The stack of images extracted from the micrographs. Effectively a PyTorch tensor with shape (N, H, W), where N is the number of particles and (H, W) is the extracted box size.
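The column requirement above is enforced by load_df, which raises a ValueError listing any missing columns (the full list is MATCH_TEMPLATE_DF_COLUMN_ORDER in the source). Below is a minimal, stdlib-only sketch of the same check that could be run on a results CSV before constructing a ParticleStack. The helper find_missing_columns and the short REQUIRED_COLUMNS subset are hypothetical and for illustration only.

```python
import csv
import io

# Hypothetical subset of the required columns, for illustration only; the real
# check in load_df uses the full column list documented above.
REQUIRED_COLUMNS = ["mip", "pos_x", "pos_y", "micrograph_path"]


def find_missing_columns(csv_file, required=REQUIRED_COLUMNS):
    """Return the required column names absent from the CSV header, in order."""
    header = next(csv.reader(csv_file))
    present = set(header)
    return [col for col in required if col not in present]


# Usage with an in-memory CSV that lacks the micrograph_path column
example = io.StringIO("mip,pos_x,pos_y\n9.1,120,340\n")
missing = find_missing_columns(example)
if missing:
    print(f"Missing the following columns in DataFrame: {missing}")
```

Reading only the header row keeps the check cheap even for large result tables.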

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
class ParticleStack(BaseModel2DTM):
    """Pydantic model for dealing with particle stack data.

    Attributes
    ----------
    df_path : str
        Path to the DataFrame containing the particle data. The DataFrame must have
        the following columns (see the documentation for further information):

          - mip
          - scaled_mip
          - correlation_mean
          - correlation_variance
          - total_correlations
          - pos_x
          - pos_y
          - pos_x_img
          - pos_y_img
          - pos_x_img_angstrom
          - pos_y_img_angstrom
          - psi
          - theta
          - phi
          - relative_defocus
          - refined_relative_defocus
          - defocus_u
          - defocus_v
          - astigmatism_angle
          - pixel_size
          - refined_pixel_size
          - voltage
          - spherical_aberration
          - amplitude_contrast_ratio
          - phase_shift
          - ctf_B_factor
          - micrograph_path
          - template_path
          - mip_path
          - scaled_mip_path
          - psi_path
          - theta_path
          - phi_path
          - defocus_path
          - correlation_average_path
          - correlation_variance_path

    extracted_box_size : tuple[int, int]
        The size (height, width) of the extracted particle boxes, in pixels.
    original_template_size : tuple[int, int]
        The original size of the template used during the matching process. Should be
        smaller than the extracted box size.
    image_stack : ExcludedTensor
        The stack of images extracted from the micrographs. Effectively a PyTorch
        tensor with shape (N, H, W), where N is the number of particles and (H, W) is
        the extracted box size.
    """

    model_config: ClassVar = ConfigDict(arbitrary_types_allowed=True)

    # Serialized fields
    df_path: str
    extracted_box_size: tuple[int, int]
    original_template_size: tuple[int, int]

    # Imported tabular data (not serialized)
    _df: pd.DataFrame

    # Cropped out view of the particles from images
    image_stack: ExcludedTensor

    def __init__(self, skip_df_load: bool = False, **data: Any):
        """Initialize the ParticleStack object.

        Parameters
        ----------
        skip_df_load : bool, optional
            Whether to skip loading the DataFrame, by default False, in which case the
            DataFrame is loaded automatically from `df_path`.
        **data : Any
            Keyword arguments used to initialize the model fields.
        """
        super().__init__(**data)

        if not skip_df_load:
            self.load_df()

    def load_df(self) -> None:
        """Load the DataFrame from the specified path.

        Raises
        ------
        ValueError
            If the DataFrame is missing required columns.
        """
        tmp_df = pd.read_csv(self.df_path)

        # Validate the DataFrame columns
        missing_columns = [
            col for col in MATCH_TEMPLATE_DF_COLUMN_ORDER if col not in tmp_df.columns
        ]
        if missing_columns:
            raise ValueError(
                f"Missing the following columns in DataFrame: {missing_columns}"
            )

        self._df = tmp_df

    def _get_position_reference_columns(self) -> tuple[str, str]:
        """Get the position reference columns based on the DataFrame."""
        y_col = "refined_pos_y" if "refined_pos_y" in self._df.columns else "pos_y"
        x_col = "refined_pos_x" if "refined_pos_x" in self._df.columns else "pos_x"
        return y_col, x_col

    def construct_image_stack(
        self,
        pos_reference: Literal["center", "top-left"] = "top-left",
        handle_bounds: Literal["pad", "error"] = "pad",
        padding_mode: Literal["constant", "reflect", "replicate"] = "constant",
        padding_value: float = 0.0,
    ) -> torch.Tensor:
        """Construct stack of images from the DataFrame (updates image_stack in-place).

        This method extracts the boxes using the refined position columns
        (refined_pos_x, refined_pos_y) when they are present in the DataFrame, falling
        back to the unrefined positions (pos_x, pos_y) otherwise. When using the
        top-left reference position, the boxes are extracted as follows, where the
        dots represent the particle in the image:

        Example:
            :                +----------------------------------+
            :                |                                  |
            :                |                                  |
            :                |     (x, y) *=== box_w ===+       |
            :                |            |             |       |
            :                |            |     ....  box_h     |
            :           img_height        |    ......   |       |
            :                |            |     ....    |       |
            :                |            |             |       |
            :                |            +=============+       |
            :                |                                  |
            :                +------------ img_width -----------+

        When the center reference is used, the position columns in the DataFrame are
        interpreted as the center of the particle, and the boxes are extracted around
        this (x, y) position as follows:

        Example:
            :                +----------------------------------+
            :                |                                  |
            :                |                                  |
            :                |            +=== box_w ===+       |
            :                |            |             |       |
            :                |            |     ....    |       |
            :           img_height        |(x, y).*.. box_h     |
            :                |            |     ....    |       |
            :                |            |             |       |
            :                |            +=============+       |
            :                |                                  |
            :                +------------ img_width -----------+

        Parameters
        ----------
        pos_reference : Literal["center", "top-left"], optional
            The reference point for the positions, by default "top-left". If
            "top-left", (x, y) is the top-left corner of the template-sized region
            containing the particle; if "center", (x, y) is the particle center. In
            both cases the box of `extracted_box_size` is placed so that it stays
            centered on the particle: its top-left corner lies
            (extracted_box_size - original_template_size) // 2 pixels above and to
            the left of the template's top-left corner. The position columns used are
            always pos_x and pos_y, or refined_pos_x and refined_pos_y if available.
            Leopard-EM uses the "top-left" reference position; unless you know the
            data were processed differently, do not change this value.
        handle_bounds : Literal["pad", "error"], optional
            How to handle the bounds of the image, by default "pad". If "pad", the image
            will be padded according to `padding_mode` (and `padding_value`). If
            "error", an error will be raised if any region exceeds the image bounds.
            NOTE: clipping is not supported because the returned stack could then
            have inhomogeneous sizes.
        padding_mode : Literal["constant", "reflect", "replicate"], optional
            The padding mode to use when padding the image, by default "constant".
            "constant" pads with the value `padding_value`, "reflect" pads with the
            reflection of the image at the edge, and "replicate" pads with the last
            pixel of the image. These match the modes available in
            `torch.nn.functional.pad`.
        padding_value : float, optional
            The value to use for padding when `padding_mode` is "constant", by default
            0.0.

        Returns
        -------
        torch.Tensor
            The stack of images. The same tensor is also stored in the 'image_stack'
            attribute.
        """
        # Determine which position columns to use (refined if available)
        y_col, x_col = self._get_position_reference_columns()

        # Create an empty tensor to store the image stack
        h, w = self.original_template_size
        box_h, box_w = self.extracted_box_size
        image_stack = torch.zeros((self.num_particles, *self.extracted_box_size))

        # Find the indexes in the DataFrame that correspond to each unique image
        image_index_groups = self._df.groupby("micrograph_path").groups
        for img_path, indexes in image_index_groups.items():
            img = load_mrc_image(img_path)

            pos_y = self._df.loc[indexes, y_col].to_numpy()
            pos_x = self._df.loc[indexes, x_col].to_numpy()

            # If the position reference is "center", shift (x, y) by half the original
            # template width/height so reference is now the top-left corner
            if pos_reference == "center":
                pos_y = pos_y - h // 2
                pos_x = pos_x - w // 2

            # Our reference is now a top-left corner of a box of the original template
            # shape, BUT we want a slightly larger box of extracted_box_size AND this
            # box to be centered around the particle. Therefore, need to shift the
            # position half the difference between the original template size and
            # the extracted box size.
            # Subtract out of place since to_numpy() may return a read-only array.
            pos_y = pos_y - (box_h - h) // 2
            pos_x = pos_x - (box_w - w) // 2

            pos_y = torch.tensor(pos_y)
            pos_x = torch.tensor(pos_x)

            # Code logic is simplified by only using the top-left reference position
            # in the `get_cropped_image_regions` function. Relative referencing handled
            # by the ParticleStack class.
            cropped_images = get_cropped_image_regions(
                img,
                pos_y,
                pos_x,
                self.extracted_box_size,
                pos_reference="top-left",
                handle_bounds=handle_bounds,
                padding_mode=padding_mode,
                padding_value=padding_value,
            )
            image_stack[indexes] = cropped_images

        self.image_stack = image_stack

        return image_stack

    def construct_cropped_statistic_stack(
        self,
        stat: Literal[
            "mip",
            "scaled_mip",
            "correlation_average",
            "correlation_variance",
            "defocus",
            "psi",
            "theta",
            "phi",
        ],
        handle_bounds: Literal["pad", "error"] = "pad",
        padding_mode: Literal["constant", "reflect", "replicate"] = "constant",
        padding_value: float = 0.0,
    ) -> torch.Tensor:
        """Return a tensor of the specified statistic for each cropped image.

        NOTE: This function is very similar to `construct_image_stack`, but it crops
        one of the statistic result maps instead of the micrograph. The output shape is
        (N, H - h + 1, W - w + 1).

        Parameters
        ----------
        stat : Literal["mip", "scaled_mip", "correlation_average",
            "correlation_variance", "defocus", "psi", "theta", "phi"]
            The statistic to crop; its result map is read from the `{stat}_path`
            column of the DataFrame.
        handle_bounds : Literal["pad", "error"], optional
            How to handle the bounds of the image, by default "pad". If "pad", the image
            will be padded according to `padding_mode` (and `padding_value`). If
            "error", an error will be raised if any region exceeds the image bounds.
            NOTE: clipping is not supported because the returned stack could then
            have inhomogeneous sizes.
        padding_mode : Literal["constant", "reflect", "replicate"], optional
            The padding mode to use when padding the image, by default "constant".
            "constant" pads with the value `padding_value`, "reflect" pads with the
            reflection of the image at the edge, and "replicate" pads with the last
            pixel of the image. These match the modes available in
            `torch.nn.functional.pad`.
        padding_value : float, optional
            The value to use for padding when `padding_mode` is "constant", by default
            0.0.

        Returns
        -------
        torch.Tensor
            The stack of statistics with shape (N, H - h + 1, W - w + 1) where N is the
            number of particles and (H, W) is the extracted box size with (h, w) being
            the original template size.
        """
        stat_col = f"{stat}_path"
        y_col, x_col = self._get_position_reference_columns()

        if stat_col not in self._df.columns:
            raise ValueError(f"Statistic '{stat}' not found in the DataFrame.")

        # Create an empty tensor to store the stat stack
        h, w = self.original_template_size
        box_h, box_w = self.extracted_box_size
        stat_stack = torch.zeros((self.num_particles, box_h - h + 1, box_w - w + 1))

        # Find the indexes in the DataFrame that correspond to each unique stat map
        stat_index_groups = self._df.groupby(stat_col).groups

        # Loop over each unique stat map and extract the particles
        for stat_path, indexes in stat_index_groups.items():
            stat_map = load_mrc_image(stat_path)

            # Positions reference the exact pixel of the statistic (top-left), so the
            # crop must account for the relative extracted box size
            pos_y = self._df.loc[indexes, y_col].to_numpy()
            pos_x = self._df.loc[indexes, x_col].to_numpy()

            # NOTE: For both references, shift x and y by half the difference between
            # the extracted box size and the original template shape so that the
            # padding around the statistic peak is symmetric. Subtract out of place
            # since to_numpy() may return a read-only array.
            pos_y = pos_y - (box_h - h) // 2
            pos_x = pos_x - (box_w - w) // 2

            pos_y = torch.tensor(pos_y)
            pos_x = torch.tensor(pos_x)

            cropped_stat_maps = get_cropped_image_regions(
                stat_map,
                pos_y,
                pos_x,
                (box_h - h + 1, box_w - w + 1),
                pos_reference="top-left",
                handle_bounds=handle_bounds,
                padding_mode=padding_mode,
                padding_value=padding_value,
            )
            stat_stack[indexes] = cropped_stat_maps

        return stat_stack

    def construct_filter_stack(
        self, preprocess_filters: PreprocessingFilters, output_shape: tuple[int, int]
    ) -> torch.Tensor:
        """Get stack of Fourier filters from filter config and reference micrographs.

        Note that here the filters are assumed to be applied globally (i.e. no local
        whitening, etc. is being done). Whitening filters are calculated with reference
        to each original micrograph in the DataFrame.

        Parameters
        ----------
        preprocess_filters : PreprocessingFilters
            Configuration object of filters to apply.
        output_shape : tuple[int, int]
            What shape along the last two dimensions the filters should be.

        Returns
        -------
        torch.Tensor
            The stack of filters with shape (N, h, w) where N is the number of particles
            and (h, w) is the output shape.
        """
        # Create an empty tensor to store the filter stack
        filter_stack = torch.zeros((self.num_particles, *output_shape))

        # Find the indexes in the DataFrame that correspond to each unique image
        image_index_groups = self._df.groupby("micrograph_path").groups

        # Loop over each unique image and extract the particles
        for img_path, indexes in image_index_groups.items():
            img = load_mrc_image(img_path)

            image_dft = torch.fft.rfftn(img)  # pylint: disable=not-callable
            image_dft[0, 0] = 0 + 0j
            cumulative_filter = preprocess_filters.get_combined_filter(
                ref_img_rfft=image_dft,
                output_shape=output_shape,
            )

            filter_stack[indexes] = cumulative_filter

        return filter_stack

    @property
    def df_columns(self) -> list[str]:
        """Get the columns of the DataFrame."""
        return self._df.columns.tolist()

    @property
    def num_particles(self) -> int:
        """Get the number of particles in the stack."""
        return len(self._df)

    def get_relative_defocus(
        self,
        prefer_refined_defocus: bool = True,
    ) -> torch.Tensor:
        """Get the relative defocus values for each particle.

        Parameters
        ----------
        prefer_refined_defocus : bool, optional
            Whether to use the refined defocus values (columns prefixed with 'refined_')
            or not, by default True.

        Returns
        -------
        torch.Tensor
            The relative defocus values for each particle.

        Warnings
        --------
            If prefer_refined_defocus is True, warns when the
            'refined_relative_defocus' column is missing or contains NaN or inf
            values, and falls back to the unrefined 'relative_defocus' values.
        """
        rel_defocus_col = "relative_defocus"
        # The refined column must be present AND contain no NaN or inf values
        if prefer_refined_defocus:
            if "refined_relative_defocus" not in self._df.columns:
                warnings.warn(
                    "Refined defocus values not found in DataFrame, using original "
                    "defocus values...",
                    stacklevel=2,
                )
            elif _any_nan_or_inf(self._df["refined_relative_defocus"]):
                warnings.warn(
                    "Refined defocus values contain NaN or inf values, using original "
                    "defocus values...",
                    stacklevel=2,
                )
            else:
                rel_defocus_col = "refined_relative_defocus"

        return torch.tensor(self._df[rel_defocus_col].to_numpy())

    def get_absolute_defocus(
        self, prefer_refined_defocus: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get the absolute defocus values for each particle.

        NOTE: If the refined defocus values are requested but unusable (the column is
        missing or contains NaN or inf values), a UserWarning is emitted and the
        original defocus values are used instead.

        Parameters
        ----------
        prefer_refined_defocus : bool, optional
            Whether to use the refined defocus values
            (columns prefixed with 'refined_') or not, by default True.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            A tuple of two tensors containing the absolute defocus values along the
            major (defocus_u) and minor axes (defocus_v), respectively in units of
            Angstroms.
        """
        particle_defocus = self.get_relative_defocus(prefer_refined_defocus)
        defocus_u = torch.tensor(self._df["defocus_u"].to_numpy()) + particle_defocus
        defocus_v = torch.tensor(self._df["defocus_v"].to_numpy()) + particle_defocus

        return defocus_u, defocus_v

    def get_pixel_size(
        self,
        prefer_refined_pixel_size: bool = True,
    ) -> torch.Tensor:
        """Get the pixel size values for each particle.

        Parameters
        ----------
        prefer_refined_pixel_size : bool, optional
            Whether to use the refined pixel size values
            (columns prefixed with 'refined_') or not, by default True.

        Returns
        -------
        torch.Tensor
            The pixel size values for each particle.

        Warnings
        --------
            If prefer_refined_pixel_size is True, warns when the 'refined_pixel_size'
            column is missing or contains NaN or inf values, and falls back to the
            unrefined 'pixel_size' values.
        """
        pixel_size_col = "pixel_size"
        if prefer_refined_pixel_size:
            if "refined_pixel_size" not in self._df.columns:
                warnings.warn(
                    "Refined pixel size not found in DataFrame, using original"
                    " pixel size values...",
                    stacklevel=2,
                )
            elif _any_nan_or_inf(self._df["refined_pixel_size"]):
                warnings.warn(
                    "Refined pixel size contains NaN or inf values, using original"
                    " pixel size values...",
                    stacklevel=2,
                )
            else:
                pixel_size_col = "refined_pixel_size"

        return torch.tensor(self._df[pixel_size_col].to_numpy())

    def get_euler_angles(self, prefer_refined_angles: bool = True) -> torch.Tensor:
        """Return the Euler angles (phi, theta, psi) of all particles as a tensor.

        Parameters
        ----------
        prefer_refined_angles : bool, optional
            When true, the refined Euler angles are used (columns prefixed with
            'refined_'), otherwise the original angles are used, by default True.

        Returns
        -------
        torch.Tensor
            A tensor of shape (N, 3) where N is the number of particles and the columns
            correspond to (phi, theta, psi) in ZYZ format.
        """
        # Ensure all three refined columns are present, warning if not
        phi_col = "phi"
        theta_col = "theta"
        psi_col = "psi"
        if prefer_refined_angles:
            if not all(
                x in self._df.columns
                for x in ["refined_phi", "refined_theta", "refined_psi"]
            ):
                warnings.warn(
                    "Refined angles not found in DataFrame, using original angles...",
                    stacklevel=2,
                )
            else:
                phi_col = "refined_phi"
                theta_col = "refined_theta"
                psi_col = "refined_psi"

        # Get the angles from the DataFrame
        phi = torch.tensor(self._df[phi_col].to_numpy())
        theta = torch.tensor(self._df[theta_col].to_numpy())
        psi = torch.tensor(self._df[psi_col].to_numpy())

        return torch.stack((phi, theta, psi), dim=-1)

    def __getitem__(self, key: str) -> Any:
        """Get an item from the DataFrame."""
        try:
            return self._df[key]
        except KeyError as err:
            raise KeyError(f"Key '{key}' not found in underlying DataFrame.") from err

    def set_column(self, column_name: str, value: Any) -> None:
        """Set a column in the underlying DataFrame.

        Parameters
        ----------
        column_name : str
            The name of the column to set.
        value : Any
            The value to set the column to (a scalar or an array-like of length N).
        """
        self._df.loc[:, column_name] = value

    def get_dataframe_copy(self) -> pd.DataFrame:
        """Return a copy of the underlying DataFrame.

        Returns
        -------
        pd.DataFrame
            A copy of the underlying DataFrame.
        """
        return self._df.copy()
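The offset arithmetic in construct_image_stack reduces to two shifts: with the "center" reference the position is first moved to the template's top-left corner by (h // 2, w // 2), and then, for either reference, the corner is moved up and left by half the difference between the extracted box and the template so that the larger box stays centered on the particle. A stdlib-only sketch of that arithmetic for a single particle follows; box_top_left is a hypothetical helper, and the 64-pixel template and 96-pixel box are illustrative values only.

```python
def box_top_left(pos_y, pos_x, template_size, box_size, pos_reference="top-left"):
    """Return the (y, x) top-left corner of the extraction box for one particle.

    Hypothetical helper mirroring the offsets used in construct_image_stack.
    """
    h, w = template_size
    box_h, box_w = box_size
    if pos_reference == "center":
        # Move from the particle center to the template's top-left corner
        pos_y, pos_x = pos_y - h // 2, pos_x - w // 2
    # Enlarge the template-sized box to box_size while keeping it centered
    return pos_y - (box_h - h) // 2, pos_x - (box_w - w) // 2


# The same particle described with both references gives the same box corner
print(box_top_left(100, 200, (64, 64), (96, 96)))            # (84, 184)
print(box_top_left(132, 232, (64, 64), (96, 96), "center"))  # (84, 184)
```

The 96-pixel box then covers rows 84 to 179, leaving a 16-pixel margin above and below the 64-pixel template region (rows 100 to 163).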
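get_relative_defocus and get_pixel_size share one fallback rule: use the refined_-prefixed column only when it exists and every value is finite; otherwise emit a warning and use the original column (get_euler_angles, as written, only checks that the three refined angle columns exist). A stdlib-only sketch of that rule, followed by the sum performed in get_absolute_defocus, is shown below. select_column is a hypothetical helper, and a dict of lists stands in for the DataFrame so the sketch runs without pandas.

```python
import math
import warnings


def select_column(columns, base, prefer_refined=True):
    """Return 'refined_<base>' if usable, otherwise warn and return base.

    `columns` is a dict of column name -> list of floats standing in for a
    DataFrame (an assumption made so the sketch runs without pandas).
    """
    refined = f"refined_{base}"
    if not prefer_refined:
        return base
    if refined not in columns:
        warnings.warn(f"'{refined}' not found, using '{base}'...", stacklevel=2)
        return base
    if not all(math.isfinite(v) for v in columns[refined]):
        warnings.warn(
            f"'{refined}' contains NaN or inf, using '{base}'...", stacklevel=2
        )
        return base
    return refined


particles = {
    "defocus_u": [15000.0, 15000.0],
    "relative_defocus": [0.0, 50.0],
    "refined_relative_defocus": [12.5, float("nan")],
}
# The NaN makes the refined column unusable, so this warns and falls back
col = select_column(particles, "relative_defocus")
# Absolute defocus along the major axis, as in get_absolute_defocus
defocus_u = [u + d for u, d in zip(particles["defocus_u"], particles[col])]
print(col, defocus_u)  # relative_defocus [15000.0, 15050.0]
```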

df_columns property

Get the columns of the DataFrame.

num_particles property

Get the number of particles in the stack.

__getitem__(key)

Get an item from the DataFrame.

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def __getitem__(self, key: str) -> Any:
    """Get an item from the DataFrame."""
    try:
        return self._df[key]
    except KeyError as err:
        raise KeyError(f"Key '{key}' not found in underlying DataFrame.") from err
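__getitem__ simply delegates to the DataFrame, so indexing a ParticleStack with a column name returns the corresponding pandas Series, and an unknown key is re-raised as a KeyError whose __cause__ is the original error (via raise ... from err). The pattern can be sketched without pandas as follows; ColumnView is a hypothetical stand-in that wraps a plain dict of columns.

```python
class ColumnView:
    """Hypothetical stand-in that delegates item access to a dict of columns."""

    def __init__(self, columns):
        self._columns = columns

    def __getitem__(self, key):
        try:
            return self._columns[key]
        except KeyError as err:
            # Re-raise with a clearer message; `from err` records the original error
            raise KeyError(f"Key '{key}' not found in underlying DataFrame.") from err


view = ColumnView({"mip": [8.2, 9.7]})
print(view["mip"])  # [8.2, 9.7]
try:
    view["refined_mip"]
except KeyError as exc:
    print(exc.args[0])  # Key 'refined_mip' not found in underlying DataFrame.
    print(type(exc.__cause__).__name__)  # KeyError
```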

__init__(skip_df_load=False, **data)

Initialize the ParticleStack object.

Parameters:

Name Type Description Default
skip_df_load bool

Whether to skip loading the DataFrame, by default False, in which case the DataFrame is loaded automatically from df_path.

False
**data Any

Keyword arguments used to initialize the model fields.

{}
Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def __init__(self, skip_df_load: bool = False, **data: Any):
    """Initialize the ParticleStack object.

    Parameters
    ----------
    skip_df_load : bool, optional
        Whether to skip loading the DataFrame, by default False, in which case the
        DataFrame is loaded automatically from `df_path`.
    **data : Any
        Keyword arguments used to initialize the model fields.
    """
    super().__init__(**data)

    if not skip_df_load:
        self.load_df()
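With skip_df_load=True the constructor only initializes the model fields, and the CSV at df_path is read later by an explicit load_df() call. On the real class, methods such as num_particles read self._df, so they should only be called after load_df() has run. The pure-Python stand-in below mimics that control flow so it runs without Leopard-EM installed; LazyParticleTable is hypothetical and performs none of the Pydantic validation or column checks of the real class.

```python
import csv
import io


class LazyParticleTable:
    """Hypothetical stand-in mimicking ParticleStack's skip_df_load control flow."""

    def __init__(self, csv_source, skip_df_load=False):
        self.csv_source = csv_source
        self.rows = None  # plays the role of the private _df attribute
        if not skip_df_load:
            self.load_df()

    def load_df(self):
        # Read every row up front, as pd.read_csv does in the real load_df
        self.rows = list(csv.DictReader(self.csv_source))


csv_text = io.StringIO("mip,pos_x,pos_y\n9.1,120,340\n")
table = LazyParticleTable(csv_text, skip_df_load=True)
print(table.rows)  # None: nothing has been read yet
table.load_df()
print(len(table.rows), table.rows[0]["mip"])  # 1 9.1
```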

construct_cropped_statistic_stack(stat, handle_bounds='pad', padding_mode='constant', padding_value=0.0)

Return a tensor of the specified statistic for each cropped image.

NOTE: This function is very similar to construct_image_stack, but it crops one of the statistic result maps instead of the micrograph. The output shape is (N, H - h + 1, W - w + 1).

Parameters:

Name Type Description Default
stat Literal['mip', 'scaled_mip', 'correlation_average', 'correlation_variance', 'defocus', 'psi', 'theta', 'phi']

The statistic to crop; its result map is read from the {stat}_path column of the DataFrame.

required
handle_bounds Literal['pad', 'error']

How to handle the bounds of the image, by default "pad". If "pad", the image will be padded according to padding_mode (and padding_value). If "error", an error will be raised if any region exceeds the image bounds. NOTE: clipping is not supported because the returned stack could then have inhomogeneous sizes.

'pad'
padding_mode Literal['constant', 'reflect', 'replicate']

The padding mode to use when padding the image, by default "constant". "constant" pads with the value padding_value, "reflect" pads with the reflection of the image at the edge, and "replicate" pads with the last pixel of the image. These match the modes available in torch.nn.functional.pad.

'constant'
padding_value float

The value to use for padding when padding_mode is "constant", by default 0.0.

0.0

Returns:

Type Description
Tensor

The stack of statistics with shape (N, H - h + 1, W - w + 1) where N is the number of particles and (H, W) is the extracted box size with (h, w) being the original template size.
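The crop arithmetic can be checked by hand: each statistic map is cropped to (H - h + 1, W - w + 1) starting (H - h) // 2 rows and (W - w) // 2 columns above and to the left of the particle's peak pixel, so for an even H - h the peak lands exactly in the center of the crop. The stdlib-only sketch below reproduces this on a nested list with constant padding (the analogue of handle_bounds="pad", padding_mode="constant"); crop_around_peak is a hypothetical helper, not the get_cropped_image_regions function the method actually calls.

```python
def crop_around_peak(stat_map, pos_y, pos_x, box_size, template_size, pad_value=0.0):
    """Crop a (box_h - h + 1, box_w - w + 1) window from a 2D list of floats.

    The window's top-left corner is shifted (box - template) // 2 pixels up and
    left of (pos_y, pos_x); pixels outside the map are filled with pad_value.
    """
    (box_h, box_w), (h, w) = box_size, template_size
    out_h, out_w = box_h - h + 1, box_w - w + 1
    top, left = pos_y - (box_h - h) // 2, pos_x - (box_w - w) // 2
    n_rows, n_cols = len(stat_map), len(stat_map[0])
    crop = []
    for r in range(top, top + out_h):
        crop.append(
            [
                stat_map[r][c] if 0 <= r < n_rows and 0 <= c < n_cols else pad_value
                for c in range(left, left + out_w)
            ]
        )
    return crop


# 7x7 map with a peak near the top-right edge; 8x8 box around a 4x4 template
stat_map = [[0.0] * 7 for _ in range(7)]
stat_map[1][5] = 9.5
crop = crop_around_peak(stat_map, 1, 5, (8, 8), (4, 4))
print(len(crop), len(crop[0]))  # 5 5
print(crop[2][2])  # 9.5, the peak sits in the center of the crop
print(crop[0])  # [0.0, 0.0, 0.0, 0.0, 0.0], a padded row above the map
```

The explicit bounds check matters in plain Python, where a negative index would otherwise silently wrap around to the far edge of the map.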

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def construct_cropped_statistic_stack(
    self,
    stat: Literal[
        "mip",
        "scaled_mip",
        "correlation_average",
        "correlation_variance",
        "defocus",
        "psi",
        "theta",
        "phi",
    ],
    handle_bounds: Literal["pad", "error"] = "pad",
    padding_mode: Literal["constant", "reflect", "replicate"] = "constant",
    padding_value: float = 0.0,
) -> torch.Tensor:
    """Return a tensor of the specified statistic for each cropped image.

    NOTE: This function is very similar to `construct_image_stack`, but it crops one
    of the statistic result maps instead of the micrograph. The output shape is
    (N, H - h + 1, W - w + 1).

    Parameters
    ----------
    stat : Literal["mip", "scaled_mip", "correlation_average",
        "correlation_variance", "defocus", "psi", "theta", "phi"]
        The statistic to crop; its result map is read from the `{stat}_path` column
        of the DataFrame.
    handle_bounds : Literal["pad", "error"], optional
        How to handle the bounds of the image, by default "pad". If "pad", the image
        will be padded according to `padding_mode` (and `padding_value`). If "error",
        an error will be raised if any region exceeds the image bounds. NOTE:
        clipping is not supported because the returned stack could then have
        inhomogeneous sizes.
    padding_mode : Literal["constant", "reflect", "replicate"], optional
        The padding mode to use when padding the image, by default "constant".
        "constant" pads with the value `padding_value`, "reflect" pads with the
        reflection of the image at the edge, and "replicate" pads with the last
        pixel of the image. These match the modes available in
        `torch.nn.functional.pad`.
    padding_value : float, optional
        The value to use for padding when `padding_mode` is "constant", by default
        0.0.

    Returns
    -------
    torch.Tensor
        The stack of statistics with shape (N, H - h + 1, W - w + 1) where N is the
        number of particles and (H, W) is the extracted box size with (h, w) being
        the original template size.
    """
    stat_col = f"{stat}_path"
    y_col, x_col = self._get_position_reference_columns()

    if stat_col not in self._df.columns:
        raise ValueError(f"Statistic '{stat}' not found in the DataFrame.")

    # Create an empty tensor to store the stat stack
    h, w = self.original_template_size
    box_h, box_w = self.extracted_box_size
    stat_stack = torch.zeros((self.num_particles, box_h - h + 1, box_w - w + 1))

    # Find the indexes in the DataFrame that correspond to each unique stat map
    stat_index_groups = self._df.groupby(stat_col).groups

    # Loop over each unique stat map and extract the particles
    for stat_path, indexes in stat_index_groups.items():
        stat_map = load_mrc_image(stat_path)

        # Positions index the statistic map directly (top-left reference), but
        # they must still be adjusted for the larger extracted box size
        pos_y = self._df.loc[indexes, y_col].to_numpy()
        pos_x = self._df.loc[indexes, x_col].to_numpy()

        # NOTE: Shift both x and y by half the difference between the original
        # template shape and the extracted box so that the padding around the
        # statistic peak is symmetric.
        pos_y -= (box_h - h) // 2
        pos_x -= (box_w - w) // 2

        pos_y = torch.tensor(pos_y)
        pos_x = torch.tensor(pos_x)

        cropped_stat_maps = get_cropped_image_regions(
            stat_map,
            pos_y,
            pos_x,
            (box_h - h + 1, box_w - w + 1),
            pos_reference="top-left",
            handle_bounds=handle_bounds,
            padding_mode=padding_mode,
            padding_value=padding_value,
        )
        stat_stack[indexes] = cropped_stat_maps

    return stat_stack
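
The shape arithmetic above can be checked with a short worked example. This is a minimal sketch in plain Python, assuming a hypothetical 518 × 518 extracted box and a 512 × 512 original template; the helpers `cropped_stat_shape` and `stat_crop_top_left` are illustrative names, not part of Leopard-EM. Only H - h + 1 template placements fit along an axis of length H, and shifting the peak position back by (H - h) // 2 puts the particle's own peak at the center of the crop.

```python
def cropped_stat_shape(box_size, template_size):
    """Shape of one cropped statistic map: (H - h + 1, W - w + 1)."""
    (box_h, box_w), (h, w) = box_size, template_size
    return (box_h - h + 1, box_w - w + 1)


def stat_crop_top_left(pos_y, pos_x, box_size, template_size):
    """Top-left corner of the statistic crop for a peak at (pos_y, pos_x).

    Mirrors the shift in construct_cropped_statistic_stack: move back by half
    the difference between the extracted box and template sizes.
    """
    (box_h, box_w), (h, w) = box_size, template_size
    return pos_y - (box_h - h) // 2, pos_x - (box_w - w) // 2


box, template = (518, 518), (512, 512)
shape = cropped_stat_shape(box, template)  # (7, 7)
top, left = stat_crop_top_left(1000, 2000, box, template)  # (997, 1997)
# Peak location inside the crop: (3, 3), the center of the 7 x 7 window
peak_in_crop = (1000 - top, 2000 - left)
```

If H - h is odd, the crop has an even side length and the peak falls half a pixel before the geometric center of the window.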

construct_filter_stack(preprocess_filters, output_shape)

Get stack of Fourier filters from filter config and reference micrographs.

The filters are assumed to be applied globally (i.e. no local whitening or similar), so one filter is computed per micrograph and shared by all particles extracted from it. Whitening filters are calculated with reference to each original micrograph in the DataFrame.

Parameters:

  • preprocess_filters (PreprocessingFilters, required): Configuration object of filters to apply.
  • output_shape (tuple[int, int], required): Shape of the filters along their last two dimensions.

Returns:

  • Tensor: The stack of filters with shape (N, h, w), where N is the number of particles and (h, w) is the output shape.

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def construct_filter_stack(
    self, preprocess_filters: PreprocessingFilters, output_shape: tuple[int, int]
) -> torch.Tensor:
    """Get stack of Fourier filters from filter config and reference micrographs.

    The filters are assumed to be applied globally (i.e. no local whitening or
    similar), so one filter is computed per micrograph and shared by all of its
    particles. Whitening filters are calculated with reference to each original
    micrograph in the DataFrame.

    Parameters
    ----------
    preprocess_filters : PreprocessingFilters
        Configuration object of filters to apply.
    output_shape : tuple[int, int]
        Shape of the filters along their last two dimensions.

    Returns
    -------
    torch.Tensor
        The stack of filters with shape (N, h, w) where N is the number of particles
        and (h, w) is the output shape.
    """
    # Create an empty tensor to store the filter stack
    filter_stack = torch.zeros((self.num_particles, *output_shape))

    # Find the indexes in the DataFrame that correspond to each unique image
    image_index_groups = self._df.groupby("micrograph_path").groups

    # Loop over each unique image and extract the particles
    for img_path, indexes in image_index_groups.items():
        img = load_mrc_image(img_path)

        image_dft = torch.fft.rfftn(img)  # pylint: disable=not-callable
        # Zero the DC (zero-frequency) term, i.e. remove the image mean from the spectrum
        image_dft[0, 0] = 0 + 0j
        cumulative_filter = preprocess_filters.get_combined_filter(
            ref_img_rfft=image_dft,
            output_shape=output_shape,
        )

        filter_stack[indexes] = cumulative_filter

    return filter_stack
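
The per-micrograph pattern above (compute one filter per micrograph, then copy it to every particle from that micrograph) can be sketched with NumPy and pandas. This is a minimal sketch, not Leopard-EM's implementation: `hypothetical_filter` stands in for `PreprocessingFilters.get_combined_filter` and simply returns a constant filter scaled by the spectrum's RMS amplitude, and the in-memory `images` dictionary replaces `load_mrc_image`.

```python
import numpy as np
import pandas as pd


def hypothetical_filter(image_dft, output_shape):
    """Stand-in for PreprocessingFilters.get_combined_filter (assumption).

    Returns a constant filter of 1 / RMS spectral amplitude, so each micrograph
    yields a different, image-dependent filter.
    """
    rms = np.sqrt(np.mean(np.abs(image_dft) ** 2))
    return np.full(output_shape, 1.0 / rms)


def build_filter_stack(df, images, output_shape):
    """Compute one filter per micrograph and copy it to each of its particles."""
    filter_stack = np.zeros((len(df), *output_shape))
    for img_path, indexes in df.groupby("micrograph_path").groups.items():
        image_dft = np.fft.rfft2(images[img_path])
        image_dft[0, 0] = 0  # zero the DC term, as construct_filter_stack does
        # Broadcast the single (h, w) filter to every row in this group
        filter_stack[indexes.to_numpy()] = hypothetical_filter(image_dft, output_shape)
    return filter_stack


rng = np.random.default_rng(0)
images = {
    "a.mrc": rng.normal(size=(64, 64)),
    "b.mrc": 3.0 * rng.normal(size=(64, 64)),
}
df = pd.DataFrame({"micrograph_path": ["a.mrc", "b.mrc", "a.mrc"]})
stack = build_filter_stack(df, images, (16, 9))
```

Like the method above, this sketch appears to rely on the DataFrame having a default RangeIndex, so that the group index labels double as row positions in the stack.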

construct_image_stack(pos_reference='top-left', handle_bounds='pad', padding_mode='constant', padding_value=0.0)

Construct stack of images from the DataFrame (updates image_stack in-place).

Boxes are extracted using the refined position columns (refined_pos_x, refined_pos_y) when they are present in the DataFrame, falling back to the unrefined positions (pos_x, pos_y) otherwise. With the top-left reference position, the boxes are extracted as follows, where the dots represent the actual particle in the image:

Example:

         +----------------------------------+
         |                                  |
         |                                  |
         |     (x, y) *=== box_w ===+       |
         |            |             |       |
         |            |     ....  box_h     |
    img_height        |    ......   |       |
         |            |     ....    |       |
         |            |             |       |
         |            +=============+       |
         |                                  |
         +------------ img_width -----------+

With the center reference, the position columns in the DataFrame are interpreted as the center of the particle, and the boxes are extracted around this (x, y) position as follows:

Example:

         +----------------------------------+
         |                                  |
         |                                  |
         |            +=== box_w ===+       |
         |            |             |       |
         |            |     ....    |       |
    img_height        |(x, y).*.. box_h     |
         |            |     ....    |       |
         |            |             |       |
         |            +=============+       |
         |                                  |
         +------------ img_width -----------+

Parameters:

  • pos_reference (Literal['center', 'top-left'], default 'top-left'): The reference point for the positions. If "top-left", the boxes will be image[y : y + box_size, ...]; if "center", they will be image[y - box_size // 2 : y + box_size // 2, ...]. The position columns used are always pos_x and pos_y, or refined_pos_x and refined_pos_y if available. Leopard-EM uses the "top-left" reference position, and unless you know the data were processed in a different way you should not change this value.
  • handle_bounds (Literal['pad', 'error'], default 'pad'): How to handle the bounds of the image. If "pad", the image will be padded with the padding value based on the padding mode. If "error", an error will be raised if any region exceeds the image bounds. NOTE: clipping is not supported because the returned stack could have inhomogeneous sizes.
  • padding_mode (Literal['constant', 'reflect', 'replicate'], default 'constant'): The padding mode to use when padding the image. "constant" pads with the value padding_value, "reflect" pads with the reflection of the image at the edge, and "replicate" pads with the last pixel of the image. These match the modes available in torch.nn.functional.pad.
  • padding_value (float, default 0.0): The value to use for padding when padding_mode is "constant".

Returns:

  • Tensor: The stack of images with shape (N, H, W), where (H, W) is the extracted box size. This is also stored in the internal image_stack attribute.

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def construct_image_stack(
    self,
    pos_reference: Literal["center", "top-left"] = "top-left",
    handle_bounds: Literal["pad", "error"] = "pad",
    padding_mode: Literal["constant", "reflect", "replicate"] = "constant",
    padding_value: float = 0.0,
) -> torch.Tensor:
    """Construct stack of images from the DataFrame (updates image_stack in-place).

    Boxes are extracted using the refined position columns (refined_pos_x,
    refined_pos_y) when they are present in the DataFrame, falling back to the
    unrefined positions (pos_x, pos_y) otherwise. With the top-left reference
    position, the boxes are extracted as follows, where the dots represent the
    actual particle in the image:

    Example:
        :                +----------------------------------+
        :                |                                  |
        :                |                                  |
        :                |     (x, y) *=== box_w ===+       |
        :                |            |             |       |
        :                |            |     ....  box_h     |
        :           img_height        |    ......   |       |
        :                |            |     ....    |       |
        :                |            |             |       |
        :                |            +=============+       |
        :                |                                  |
        :                +------------ img_width -----------+

    When the center reference is used, the position columns in the DataFrame are
    interpreted as the center of the particle, and the boxes are extracted around
    this (x, y) position as follows:

    Example:
        :                +----------------------------------+
        :                |                                  |
        :                |                                  |
        :                |            +=== box_w ===+       |
        :                |            |             |       |
        :                |            |     ....    |       |
        :           img_height        |(x, y).*.. box_h     |
        :                |            |     ....    |       |
        :                |            |             |       |
        :                |            +=============+       |
        :                |                                  |
        :                +------------ img_width -----------+

    Parameters
    ----------
    pos_reference : Literal["center", "top-left"], optional
        The reference point for the positions, by default "top-left". If
        "top-left", the boxes will be image[y : y + box_size, ...]. If "center",
        the boxes extracted will be
        image[y - box_size // 2 : y + box_size // 2, ...].
        Columns in the DataFrame used as position references are always
        pos_x and pos_y, or refined_pos_x and refined_pos_y if available.
        Leopard-EM uses the "top-left" reference position, and unless you know the
        data were processed in a different way you should not change this value.
    handle_bounds : Literal["pad", "error"], optional
        How to handle the bounds of the image, by default "pad". If "pad", the image
        will be padded with the padding value based on the padding mode. If "error",
        an error will be raised if any region exceeds the image bounds. NOTE:
        clipping is not supported because the returned stack could have
        inhomogeneous sizes.
    padding_mode : Literal["constant", "reflect", "replicate"], optional
        The padding mode to use when padding the image, by default "constant".
        "constant" pads with the value `padding_value`, "reflect" pads with the
        reflection of the image at the edge, and "replicate" pads with the last
        pixel of the image. These match the modes available in
        `torch.nn.functional.pad`.
    padding_value : float, optional
        The value to use for padding when `padding_mode` is "constant", by default
        0.0.

    Returns
    -------
    torch.Tensor
        The stack of images with shape (N, H, W), where (H, W) is the extracted box
        size. This is also stored in the internal 'image_stack' attribute.
    """
    # Determine which position columns to use (refined if available)
    y_col, x_col = self._get_position_reference_columns()

    # Create an empty tensor to store the image stack
    h, w = self.original_template_size
    box_h, box_w = self.extracted_box_size
    image_stack = torch.zeros((self.num_particles, *self.extracted_box_size))

    # Find the indexes in the DataFrame that correspond to each unique image
    image_index_groups = self._df.groupby("micrograph_path").groups
    for img_path, indexes in image_index_groups.items():
        img = load_mrc_image(img_path)

        pos_y = self._df.loc[indexes, y_col].to_numpy()
        pos_x = self._df.loc[indexes, x_col].to_numpy()

        # If the position reference is "center", shift (x, y) by half the original
        # template width/height so reference is now the top-left corner
        if pos_reference == "center":
            pos_y = pos_y - h // 2
            pos_x = pos_x - w // 2

        # Our reference is now a top-left corner of a box of the original template
        # shape, BUT we want a slightly larger box of extracted_box_size AND this
        # box to be centered around the particle. Therefore, need to shift the
        # position half the difference between the original template size and
        # the extracted box size.
        pos_y -= (box_h - h) // 2
        pos_x -= (box_w - w) // 2

        pos_y = torch.tensor(pos_y)
        pos_x = torch.tensor(pos_x)

        # Code logic is simplified by only using the top-left reference position
        # in the `get_cropped_image_regions` function. Relative referencing handled
        # by the ParticleStack class.
        cropped_images = get_cropped_image_regions(
            img,
            pos_y,
            pos_x,
            self.extracted_box_size,
            pos_reference="top-left",
            handle_bounds=handle_bounds,
            padding_mode=padding_mode,
            padding_value=padding_value,
        )
        image_stack[indexes] = cropped_images

    self.image_stack = image_stack

    return image_stack
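
The coordinate shifts in `construct_image_stack` reduce to a small per-axis formula. A minimal sketch, assuming a hypothetical 512-pixel template and 518-pixel extracted box; `extraction_top_left` is an illustrative helper, not a Leopard-EM function. Both reference conventions yield the same box, and the template footprint ends up with equal 3-pixel margins on each side.

```python
def extraction_top_left(pos, template_size, box_size, pos_reference="top-left"):
    """Top-left corner (along one axis) of the extracted box.

    Mirrors construct_image_stack: a "center" position is first converted to the
    template's top-left corner, then the box is grown symmetrically around it.
    """
    if pos_reference == "center":
        pos = pos - template_size // 2
    elif pos_reference != "top-left":
        raise ValueError(f"Unknown pos_reference: {pos_reference}")
    return pos - (box_size - template_size) // 2


template_size, box_size = 512, 518
start = extraction_top_left(100, template_size, box_size)  # 97
start_center = extraction_top_left(356, template_size, box_size, "center")  # 97
# Margins between the template footprint [100, 612) and the box [97, 615)
margin_before = 100 - start  # 3
margin_after = (start + box_size) - (100 + template_size)  # 3
```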

get_absolute_defocus(prefer_refined_defocus=True)

Get the absolute defocus values for each particle.

NOTE: If the refined defocus values are requested but unusable (the refined_relative_defocus column is missing or contains NaN or inf values), a user warning is raised and the original defocus values are used instead.

Parameters:

  • prefer_refined_defocus (bool, default True): Whether to use the refined defocus values (columns prefixed with 'refined_').

Returns:

  • tuple[Tensor, Tensor]: Two tensors with the absolute defocus values along the major (defocus_u) and minor (defocus_v) axes, respectively, in Angstroms. The same per-particle relative defocus is added to both.

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def get_absolute_defocus(
    self, prefer_refined_defocus: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Get the absolute defocus values for each particle.

    NOTE: If the refined defocus values are requested but not present in the
    DataFrame (no column, or any NaN or inf values), a user warning is raised
    and the original defocus values are returned instead.

    Parameters
    ----------
    prefer_refined_defocus : bool, optional
        Whether to use the refined defocus values
        (columns prefixed with 'refined_') or not, by default True.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of two tensors containing the absolute defocus values along the
        major (defocus_u) and minor axes (defocus_v), respectively in units of
        Angstroms.
    """
    particle_defocus = self.get_relative_defocus(prefer_refined_defocus)
    defocus_u = torch.tensor(self._df["defocus_u"].to_numpy()) + particle_defocus
    defocus_v = torch.tensor(self._df["defocus_v"].to_numpy()) + particle_defocus

    return defocus_u, defocus_v
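
Numerically, `get_absolute_defocus` adds each particle's relative defocus to the micrograph-level defocus_u and defocus_v. A worked sketch in plain Python with made-up values in Angstroms; `absolute_defocus` is a hypothetical helper. Because the same offset is added to both axes, the astigmatism (defocus_u - defocus_v) is unchanged.

```python
def absolute_defocus(defocus_u, defocus_v, relative_defocus):
    """Absolute defocus per particle (Angstroms): micrograph value + relative offset."""
    abs_u = [du + dz for du, dz in zip(defocus_u, relative_defocus)]
    abs_v = [dv + dz for dv, dz in zip(defocus_v, relative_defocus)]
    return abs_u, abs_v


# Two particles from one micrograph with defocus_u = 15000 A, defocus_v = 14500 A
u, v = absolute_defocus([15000.0, 15000.0], [14500.0, 14500.0], [-200.0, 150.0])
# u == [14800.0, 15150.0], v == [14300.0, 14650.0]
```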

get_dataframe_copy()

Return a copy of the underlying DataFrame.

Returns:

  • DataFrame: A copy of the underlying DataFrame.
Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def get_dataframe_copy(self) -> pd.DataFrame:
    """Return a copy of the underlying DataFrame.

    Returns
    -------
    pd.DataFrame
        A copy of the underlying DataFrame.
    """
    return self._df.copy()
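
Because `get_dataframe_copy` returns a copy, edits to the result do not leak back into the particle stack. A minimal pandas sketch of that behavior, using a plain DataFrame as a stand-in for the internal `_df` (the values are made up):

```python
import pandas as pd

internal_df = pd.DataFrame({"mip": [10.0, 12.5]})  # stands in for ParticleStack._df
df_copy = internal_df.copy()  # deep copy by default (deep=True)
df_copy.loc[0, "mip"] = 99.0  # modify only the copy
```

To change the particle data itself, use `set_column` rather than editing the returned copy.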

get_euler_angles(prefer_refined_angles=True)

Return the Euler angles (phi, theta, psi) of all particles as a tensor.

Parameters:

  • prefer_refined_angles (bool, default True): When True, the refined Euler angles (columns prefixed with 'refined_') are used; otherwise the original angles are used. The refined angles are used only if all three refined columns are present; if not, a warning is raised and the original angles are returned.

Returns:

  • Tensor: A tensor of shape (N, 3), where N is the number of particles and the columns correspond to (phi, theta, psi) in the ZYZ convention.

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def get_euler_angles(self, prefer_refined_angles: bool = True) -> torch.Tensor:
    """Return the Euler angles (phi, theta, psi) of all particles as a tensor.

    Parameters
    ----------
    prefer_refined_angles : bool, optional
        When True, the refined Euler angles are used (columns prefixed with
        'refined_'), otherwise the original angles are used, by default True.

    Returns
    -------
    torch.Tensor
        A tensor of shape (N, 3) where N is the number of particles and the columns
        correspond to (phi, theta, psi) in ZYZ format.
    """
    # Ensure all three refined columns are present, warning if not
    phi_col = "phi"
    theta_col = "theta"
    psi_col = "psi"
    if prefer_refined_angles:
        if not all(
            x in self._df.columns
            for x in ["refined_phi", "refined_theta", "refined_psi"]
        ):
            warnings.warn(
                "Refined angles not found in DataFrame, using original angles...",
                stacklevel=2,
            )
        else:
            phi_col = "refined_phi"
            theta_col = "refined_theta"
            psi_col = "refined_psi"

    # Get the angles from the DataFrame
    phi = torch.tensor(self._df[phi_col].to_numpy())
    theta = torch.tensor(self._df[theta_col].to_numpy())
    psi = torch.tensor(self._df[psi_col].to_numpy())

    return torch.stack((phi, theta, psi), dim=-1)
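
The column selection in `get_euler_angles` uses the refined angles only when all three refined columns exist. A minimal NumPy/pandas sketch of that rule; the `euler_angles` helper is hypothetical, the angle values are made up, and the warning the method emits is omitted.

```python
import numpy as np
import pandas as pd


def euler_angles(df, prefer_refined=True):
    """Return an (N, 3) array of (phi, theta, psi) angles.

    Refined columns are used only if all three are present, as in get_euler_angles.
    """
    cols = ["phi", "theta", "psi"]
    refined_cols = [f"refined_{c}" for c in cols]
    if prefer_refined and all(c in df.columns for c in refined_cols):
        cols = refined_cols
    # Stack the three 1D columns so that row i is (phi_i, theta_i, psi_i)
    return np.stack([df[c].to_numpy() for c in cols], axis=-1)


df = pd.DataFrame({
    "phi": [10.0, 20.0], "theta": [30.0, 40.0], "psi": [50.0, 60.0],
    "refined_phi": [11.0, 21.0], "refined_theta": [31.0, 41.0],
})
partial = euler_angles(df)  # refined_psi missing -> original angles
df["refined_psi"] = [51.0, 61.0]
refined = euler_angles(df)  # all three refined columns present
```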

get_pixel_size(prefer_refined_pixel_size=True)

Get the pixel size for each particle.

Parameters:

  • prefer_refined_pixel_size (bool, default True): Whether to use the refined pixel size values (columns prefixed with 'refined_').

Returns:

  • Tensor: The pixel size for each particle.

Warnings:

  • Warns if the 'refined_pixel_size' column is missing or contains NaN or inf values, and falls back to the unrefined 'pixel_size' values.
Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def get_pixel_size(
    self,
    prefer_refined_pixel_size: bool = True,
) -> torch.Tensor:
    """Get the pixel size for each particle.

    Parameters
    ----------
    prefer_refined_pixel_size : bool, optional
        Whether to use the refined pixel size values
        (columns prefixed with 'refined_') or not, by default True.

    Returns
    -------
    torch.Tensor
        The pixel size for each particle.

    Warnings
    --------
        Warns if the 'refined_pixel_size' column is missing or contains NaN or
        inf values. Falls back to the unrefined 'pixel_size' values.
    """
    pixel_size_col = "pixel_size"
    if prefer_refined_pixel_size:
        if "refined_pixel_size" not in self._df.columns:
            warnings.warn(
                "Refined pixel size not found in DataFrame, using original"
                " pixel size values...",
                stacklevel=2,
            )
        elif _any_nan_or_inf(self._df["refined_pixel_size"]):
            warnings.warn(
                "Refined pixel size contains NaN or inf values, using original"
                " pixel size values...",
                stacklevel=2,
            )
        else:
            pixel_size_col = "refined_pixel_size"

    return torch.tensor(self._df[pixel_size_col].to_numpy())

get_relative_defocus(prefer_refined_defocus=True)

Get the relative defocus values for each particle.

Parameters:

  • prefer_refined_defocus (bool, default True): Whether to use the refined defocus values (columns prefixed with 'refined_').

Returns:

  • Tensor: The relative defocus values for each particle.

Warnings:

  • Warns if the 'refined_relative_defocus' column is missing or contains NaN or inf values, and falls back to the unrefined 'relative_defocus' values.
Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def get_relative_defocus(
    self,
    prefer_refined_defocus: bool = True,
) -> torch.Tensor:
    """Get the relative defocus values for each particle.

    Parameters
    ----------
    prefer_refined_defocus : bool, optional
        Whether to use the refined defocus values (columns prefixed with 'refined_')
        or not, by default True.

    Returns
    -------
    torch.Tensor
        The relative defocus values for each particle.

    Warnings
    --------
        Warns if the 'refined_relative_defocus' column is missing or contains NaN
        or inf values. Falls back to the unrefined 'relative_defocus' values.
    """
    rel_defocus_col = "relative_defocus"
    # The refined column must be present AND contain no NaN or inf values
    if prefer_refined_defocus:
        if "refined_relative_defocus" not in self._df.columns:
            warnings.warn(
                "Refined defocus values not found in DataFrame, using original "
                "defocus values...",
                stacklevel=2,
            )
        elif _any_nan_or_inf(self._df["refined_relative_defocus"]):
            warnings.warn(
                "Refined defocus values contain NaN or inf values, using original "
                "defocus values...",
                stacklevel=2,
            )
        else:
            rel_defocus_col = "refined_relative_defocus"

    return torch.tensor(self._df[rel_defocus_col].to_numpy())
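
`get_relative_defocus` and `get_pixel_size` share the same fallback rule: use the refined column only if it exists and every value is finite. A minimal sketch of that rule, assuming `np.isfinite` matches what the private `_any_nan_or_inf` helper (not shown on this page) checks; `choose_column` is a hypothetical name and the data are made up.

```python
import warnings

import numpy as np
import pandas as pd


def choose_column(df, preferred, fallback):
    """Return `preferred` if present and all finite, otherwise warn and return `fallback`."""
    if preferred not in df.columns:
        warnings.warn(f"'{preferred}' not found, using '{fallback}'", stacklevel=2)
        return fallback
    if not np.isfinite(df[preferred].to_numpy(dtype=float)).all():
        warnings.warn(f"'{preferred}' has NaN or inf values, using '{fallback}'", stacklevel=2)
        return fallback
    return preferred


df = pd.DataFrame({
    "relative_defocus": [0.0, 120.0],
    "refined_relative_defocus": [15.0, np.nan],  # e.g. one particle was not refined
})
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    col = choose_column(df, "refined_relative_defocus", "relative_defocus")
```

A single NaN is enough to discard the whole refined column, which matches the all-or-nothing behavior described in the Warnings sections above.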

load_df()

Load the DataFrame from the CSV file at df_path.

Raises:

  • ValueError: If the DataFrame is missing required columns.

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def load_df(self) -> None:
    """Load the DataFrame from the specified path.

    Raises
    ------
    ValueError
        If the DataFrame is missing required columns.
    """
    tmp_df = pd.read_csv(self.df_path)

    # Validate the DataFrame columns
    missing_columns = [
        col for col in MATCH_TEMPLATE_DF_COLUMN_ORDER if col not in tmp_df.columns
    ]
    if missing_columns:
        raise ValueError(
            f"Missing the following columns in DataFrame: {missing_columns}"
        )

    self._df = tmp_df
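
The validation step in `load_df` is a plain membership check over the expected column names. A minimal sketch with an in-memory CSV; `REQUIRED_COLUMNS` is a hypothetical four-column subset standing in for `MATCH_TEMPLATE_DF_COLUMN_ORDER`, whose full contents are not shown here, and `load_particle_csv` is an illustrative name.

```python
import io

import pandas as pd

# Hypothetical subset of MATCH_TEMPLATE_DF_COLUMN_ORDER, for illustration only
REQUIRED_COLUMNS = ["mip", "pos_x", "pos_y", "micrograph_path"]


def load_particle_csv(path_or_buffer):
    """Read a particle CSV and raise if any required column is missing."""
    df = pd.read_csv(path_or_buffer)
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Missing the following columns in DataFrame: {missing}")
    return df


message = ""
try:
    load_particle_csv(io.StringIO("mip,pos_x,pos_y\n10.5,100,200\n"))
except ValueError as err:
    message = str(err)  # Missing the following columns in DataFrame: ['micrograph_path']
```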

set_column(column_name, value)

Set a column in the underlying DataFrame.

Parameters:

  • column_name (str, required): The name of the column to set.
  • value (Any, required): The value to set the column to.
Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def set_column(self, column_name: str, value: Any) -> None:
    """Set a column in the underlying DataFrame.

    Parameters
    ----------
    column_name : str
        The name of the column to set.
    value : Any
        The value to set the column to.
    """
    self._df.loc[:, column_name] = value
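
`set_column` assigns through `.loc[:, column_name]`, so pandas broadcasts a scalar to every particle, while an array must supply exactly one value per particle. A minimal sketch on a plain DataFrame standing in for the internal `_df` (the values and the template file name are made up):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"mip": [10.0, 12.5, 9.8]})  # stands in for ParticleStack._df
df.loc[:, "template_path"] = "template.mrc"  # scalar: same value for every row
df.loc[:, "refined_psi"] = np.array([1.0, 2.0, 3.0])  # array: one value per row
```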

get_cropped_image_regions(image, pos_y, pos_x, box_size, pos_reference='top-left', handle_bounds='pad', padding_mode='constant', padding_value=0.0)

Extracts regions from an image into a stack of cropped images.

The pos_reference argument determines how the (y, x) coordinates are interpreted when extracting boxes:

  • If pos_reference="center": The (y, x) coordinate refers to the center of the box. The box extends from (y - height // 2, x - width // 2) to (y + height // 2, x + width // 2).

    Example:

            +------------------+
            |                  |
          height      * (y, x) |
            |                  |
            +------ width -----+

  • If pos_reference="top-left": The (y, x) coordinate refers to the top-left corner of the box. The box extends from (y, x) to (y + height, x + width).

    Example:

     (y, x) *------ width -----+
            |                  |
            |                height
            |                  |
            +------------------+

Parameters:

  • image (Tensor | ndarray, required): The input image from which to extract the regions.
  • pos_y (Tensor | ndarray, required): The y positions of the regions to extract. Type must match image.
  • pos_x (Tensor | ndarray, required): The x positions of the regions to extract. Type must match image.
  • box_size (int | tuple[int, int], required): The size of the box to extract. If an integer is passed, the box will be square.
  • pos_reference (Literal['center', 'top-left'], default 'top-left'): The reference point for the positions. If "center", the boxes extracted will be image[y - box_size // 2 : y + box_size // 2, ...]. If "top-left", the boxes will be image[y : y + box_size, ...].
  • handle_bounds (Literal['pad', 'error'], default 'pad'): How to handle the bounds of the image. If "pad", the image will be padded with the padding value based on the padding mode. If "error", an error will be raised if any region exceeds the image bounds. Note that clipping is not supported because the returned stack could have inhomogeneous sizes.
  • padding_mode (Literal['constant', 'reflect', 'replicate'], default 'constant'): The padding mode to use when padding the image. "constant" pads with the value padding_value, "reflect" pads with the reflection of the image at the edge, and "replicate" pads with the last pixel of the image. These match the modes available in torch.nn.functional.pad.
  • padding_value (float, default 0.0): The value to use for padding when padding_mode is "constant".

Returns:

  • Tensor | ndarray: The stack of cropped images extracted from the input image. Type will match the input image type.

Raises:

  • ValueError: If pos_reference is not one of "center" or "top-left", or if image is not a torch.Tensor or np.ndarray.

Source code in src/leopard_em/pydantic_models/data_structures/particle_stack.py
def get_cropped_image_regions(
    image: torch.Tensor | np.ndarray,
    pos_y: torch.Tensor | np.ndarray,
    pos_x: torch.Tensor | np.ndarray,
    box_size: int | tuple[int, int],
    pos_reference: Literal["center", "top-left"] = "top-left",
    handle_bounds: Literal["pad", "error"] = "pad",
    padding_mode: Literal["constant", "reflect", "replicate"] = "constant",
    padding_value: float = 0.0,
) -> torch.Tensor | np.ndarray:
    """Extracts regions from an image into a stack of cropped images.

    The `pos_reference` argument determines how the (y, x) coordinates are interpreted
    when extracting boxes:

    - If ``pos_reference="center"``:
        The (y, x) coordinate refers to the **center** of the box.
        The box extends from (y - height // 2, x - width // 2) to
        (y + height // 2, x + width // 2).

        Example:
            :                +------------------+
            :                |                  |
            :              height      * (y, x) |
            :                |                  |
            :                +------ width -----+

    - If ``pos_reference="top-left"``:
        The (y, x) coordinate refers to the **top-left corner** of the box.
        The box extends from (y, x) to (y + height, x + width).

        Example:
            :         (y, x) *------ width -----+
            :                |                  |
            :                |                height
            :                |                  |
            :                +------------------+

    Parameters
    ----------
    image : torch.Tensor | np.ndarray
        The input image from which to extract the regions.
    pos_y : torch.Tensor | np.ndarray
        The y positions of the regions to extract. Type must match `image`.
    pos_x : torch.Tensor | np.ndarray
        The x positions of the regions to extract. Type must match `image`.
    box_size : int | tuple[int, int]
        The size of the box to extract. If an integer is passed, the box will be square.
    pos_reference : Literal["center", "top-left"], optional
        The reference point for the positions, by default "top-left". If "center",
        the boxes extracted will be image[y - box_size // 2 : y + box_size // 2, ...].
        If "top-left", the boxes will be image[y : y + box_size, ...].
    handle_bounds : Literal["pad", "error"], optional
        How to handle the bounds of the image, by default "pad". If "pad", the image
        will be padded with the padding value based on the padding mode. If "error", an
        error will be raised if any region exceeds the image bounds. Note that clipping
        is not supported because the returned stack could have inhomogeneous sizes.
    padding_mode : Literal["constant", "reflect", "replicate"], optional
        The padding mode to use when padding the image, by default "constant".
        "constant" pads with the value `padding_value`, "reflect" pads with the
        reflection of the image at the edge, and "replicate" pads with the last pixel
        of the image. These match the modes available in `torch.nn.functional.pad`.
    padding_value : float, optional
        The value to use for padding when `padding_mode` is "constant", by default 0.0.

    Returns
    -------
    torch.Tensor | np.ndarray
        The stack of cropped images extracted from the input image. Type will match the
        input image type.

    Raises
    ------
    ValueError
        If `pos_reference` is not one of "center" or "top-left", or if `image` is not a
        torch.Tensor or np.ndarray.
    """
    if isinstance(box_size, int):
        box_size = (box_size, box_size)

    # The underlying numpy/torch functions only operate on the top-left corner
    # reference, so shift the position half a box height/width if using center.
    if pos_reference == "center":
        pos_y = pos_y - box_size[0] // 2
        pos_x = pos_x - box_size[1] // 2
    elif pos_reference == "top-left":
        pass
    else:
        raise ValueError(f"Unknown pos_reference: {pos_reference}")

    if isinstance(image, torch.Tensor):
        return _get_cropped_image_regions_torch(
            image=image,
            pos_y=pos_y,
            pos_x=pos_x,
            box_size=box_size,
            handle_bounds=handle_bounds,
            padding_mode=padding_mode,
            padding_value=padding_value,
        )

    if isinstance(image, np.ndarray):
        padding_mode_np = TORCH_TO_NUMPY_PADDING_MODE[padding_mode]
        return _get_cropped_image_regions_numpy(
            image=image,
            pos_y=pos_y,
            pos_x=pos_x,
            box_size=box_size,
            handle_bounds=handle_bounds,
            padding_mode=padding_mode_np,
            padding_value=padding_value,
        )

    raise ValueError(f"Unknown image type: {type(image)}")
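
For NumPy input, cropping with constant padding can be approximated by padding the image once and slicing. This is a minimal sketch, not Leopard-EM's private `_get_cropped_image_regions_numpy`: `crop_regions` is a hypothetical helper that supports only constant padding (no "reflect", "replicate", or "error" handling) and rejects any box whose corner lies more than one box size outside the image.

```python
import numpy as np


def crop_regions(image, pos_y, pos_x, box_size, pos_reference="top-left",
                 padding_value=0.0):
    """Stack of (box_h, box_w) crops from a 2D image, constant-padded at the edges."""
    if isinstance(box_size, int):
        box_size = (box_size, box_size)
    box_h, box_w = box_size
    pos_y = np.asarray(pos_y)
    pos_x = np.asarray(pos_x)

    # Convert center positions to top-left corners, as get_cropped_image_regions does
    if pos_reference == "center":
        pos_y = pos_y - box_h // 2
        pos_x = pos_x - box_w // 2
    elif pos_reference != "top-left":
        raise ValueError(f"Unknown pos_reference: {pos_reference}")

    # Padding by a full box on every side keeps corners in [-box, image size] valid
    img_h, img_w = image.shape
    out_of_range = (
        (pos_y < -box_h) | (pos_y > img_h) | (pos_x < -box_w) | (pos_x > img_w)
    )
    if out_of_range.any():
        raise ValueError("Box corner is more than one box size outside the image")
    padded = np.pad(image, ((box_h, box_h), (box_w, box_w)),
                    mode="constant", constant_values=padding_value)
    # A corner at (y, x) in the original image is at (y + box_h, x + box_w) in padded
    return np.stack([
        padded[y + box_h : y + 2 * box_h, x + box_w : x + 2 * box_w]
        for y, x in zip(pos_y, pos_x)
    ])


image = np.arange(16, dtype=float).reshape(4, 4)
crops = crop_regions(image, [0, 3], [0, 3], 2)
centered = crop_regions(image, [1], [1], 2, pos_reference="center")
```

The second crop starts at the last pixel (3, 3), so three of its four values come from the constant padding.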