"""Statistics plots for proteopy (module ``proteopy.pl.stats``)."""

import warnings
from pathlib import Path
from typing import Any, Sequence
import uuid

import numpy as np
import pandas as pd
import anndata as ad
from pandas.api.types import is_string_dtype, is_categorical_dtype
from scipy import sparse
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import leaves_list, linkage
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.patches import Patch
from matplotlib.ticker import MaxNLocator
import seaborn as sns

from proteopy.utils.anndata import check_proteodata, is_proteodata
from proteopy.utils.matplotlib import _resolve_color_scheme
from proteopy.utils.functools import partial_with_docsig
from proteopy.utils.string import sanitize_string
from proteopy.pp.stats import calculate_cv


def _validate_completeness_args(  # noqa: C901
    adata,
    axis,
    layer,
    order,
    group_by_resolution,
    group_by_partition,
    min_count,
    min_fraction,
    fraction_thresh,
    bin_width,
):
    """Validate inputs and derive working variables for completeness.

    Parameters
    ----------
    adata
        AnnData object; checked with ``check_proteodata``.
    axis
        ``0`` computes completeness per var (counted over obs);
        ``1`` computes completeness per obs (counted over var).
    layer
        Layer name, or ``None`` to use ``adata.X``.
    order
        Group order; only meaningful together with
        ``group_by_partition`` (a warning is emitted otherwise).
    group_by_resolution, group_by_partition
        Mutually exclusive grouping column names.
    min_count, min_fraction
        Mutually exclusive detection thresholds. They are only used
        with ``group_by_resolution``; otherwise a warning is emitted
        and both are reset to ``None``. When used, ``min_count`` must
        be non-negative and ``min_fraction`` must lie in ``[0, 1]``.
    fraction_thresh
        Optional reference value; must lie in ``[0, 1]``.
    bin_width
        Histogram bin width; must be positive when given.

    Returns
    -------
    list
        ``[matrix, axis_labels, n_items, axis_length, grouping_frame,
        min_count, min_fraction]``. ``axis_labels`` is
        ``(item_label, other_label)``. ``n_items`` is the number of
        items completeness is reported for. ``axis_length`` is the size
        of the axis values are counted over. ``grouping_frame`` is
        ``adata.obs`` (axis 0) or ``adata.var`` (axis 1).

    Raises
    ------
    ValueError
        On an invalid ``axis``, conflicting or out-of-range arguments,
        a ``None`` matrix, or an empty axis.
    KeyError
        If ``layer`` is not present in ``adata.layers``.
    """
    check_proteodata(adata)

    if axis not in (0, 1):
        raise ValueError(
            "`axis` must be either 0 (var) or 1 (obs)."
        )

    if (
        group_by_resolution is not None
        and group_by_partition is not None
    ):
        raise ValueError(
            "`group_by_resolution` and `group_by_partition` "
            "are mutually exclusive. Provide one or neither."
        )

    if min_count is not None and min_fraction is not None:
        raise ValueError(
            "`min_count` and `min_fraction` are mutually exclusive. "
            "Provide one or neither."
        )

    if fraction_thresh is not None and (
        fraction_thresh < 0 or fraction_thresh > 1
    ):
        raise ValueError(
            "`fraction_thresh` must be between 0 and 1."
        )

    if bin_width is not None and bin_width <= 0:
        raise ValueError(
            "`bin_width` must be a positive number."
        )

    if (
        group_by_resolution is None
        and (min_count is not None or min_fraction is not None)
    ):
        warnings.warn(
            "`min_count` and `min_fraction` are only used when "
            "`group_by_resolution` is provided. They will be "
            "ignored."
        )
        min_count = None
        min_fraction = None

    # Validate detection thresholds only when they will be applied.
    # Out-of-range values would otherwise silently mark every item as
    # undetected (min_fraction > 1) or detected (negative min_count).
    if min_count is not None and min_count < 0:
        raise ValueError(
            "`min_count` must be a non-negative integer."
        )
    if min_fraction is not None and not 0 <= min_fraction <= 1:
        raise ValueError(
            "`min_fraction` must be between 0 and 1."
        )

    if layer is None:
        matrix = adata.X
    else:
        if layer not in adata.layers:
            raise KeyError(
                f"Layer '{layer}' not found in adata.layers."
            )
        matrix = adata.layers[layer]

    if matrix is None:
        raise ValueError(
            "Selected matrix is empty; cannot compute "
            "completeness."
        )

    n_obs, n_vars = matrix.shape

    # axis 0: items are vars, counted over obs, grouped by .obs columns.
    # axis 1: items are obs, counted over vars, grouped by .var columns.
    if axis == 0:
        axis_labels = ("var", "obs")
        n_items = n_vars
        axis_length = n_obs
        grouping_frame = adata.obs
    else:
        axis_labels = ("obs", "var")
        n_items = n_obs
        axis_length = n_vars
        grouping_frame = adata.var

    if axis_length == 0:
        raise ValueError(
            "Cannot compute completeness on empty axis."
        )

    if n_items == 0:
        raise ValueError(
            "No items to compute completeness for."
        )

    if order is not None and group_by_partition is None:
        warnings.warn(
            "`order` is only used when "
            "`group_by_partition` is provided. "
            "It will be ignored."
        )

    return [
        matrix, axis_labels, n_items, axis_length,
        grouping_frame, min_count, min_fraction,
    ]


def _summary_stats(values):
    """Return a single-row DataFrame of summary statistics."""
    s = pd.Series(values) if not isinstance(
        values, pd.Series,
    ) else values
    return pd.DataFrame({
        "count": [s.count()],
        "mean": [s.mean()],
        "median": [s.median()],
        "std": [s.std()],
        "min": [s.min()],
        "max": [s.max()],
    })


def _count_nonmissing(mat, ax, zero_to_na):
    """Count non-missing values along the given axis."""
    if sparse.issparse(mat):
        mat_coo = mat.tocoo()
        data = mat_coo.data
        rows = mat_coo.row
        cols = mat_coo.col
        if zero_to_na:
            valid = (~np.isnan(data)) & (data != 0)
            if ax == 0:
                return np.bincount(
                    cols[valid],
                    minlength=mat.shape[1],
                )
            else:
                return np.bincount(
                    rows[valid],
                    minlength=mat.shape[0],
                )
        else:
            nan_mask = np.isnan(data)
            if ax == 0:
                nan_c = np.bincount(
                    cols[nan_mask],
                    minlength=mat.shape[1],
                )
                return mat.shape[0] - nan_c
            else:
                nan_c = np.bincount(
                    rows[nan_mask],
                    minlength=mat.shape[0],
                )
                return mat.shape[1] - nan_c
    else:
        values = np.asarray(mat)
        valid_mask = ~np.isnan(values)
        if zero_to_na:
            valid_mask &= values != 0
        return valid_mask.sum(axis=ax)


def _resolve_partition_order(order, available):
    """Resolve and validate group order for partition plots."""
    if order is not None:
        if isinstance(order, str):
            order = [order]
        else:
            order = list(order)
        missing = [
            g for g in order if g not in available
        ]
        if missing:
            raise ValueError(
                "Unknown group(s) in `order`: "
                f"{', '.join(map(str, missing))}.",
            )
        return order
    return sorted(available, key=str)


def _group_completeness_counts(
    matrix, axis, g_mask, zero_to_na,
):
    """Count non-missing values per item within a group.

    *g_mask* is a boolean mask over the axis being counted: rows when
    ``axis == 0``, columns otherwise. Only the selected slice of
    *matrix* is counted.

    Returns
    -------
    tuple
        ``(counts, group_size)``. ``counts`` is a float array of
        per-item non-missing counts. ``group_size`` is the number of
        selected rows/columns, as an ``int``.
    """
    # Row selection for axis 0, column selection for axis 1.
    selector = (
        (g_mask, slice(None)) if axis == 0 else (slice(None), g_mask)
    )
    per_item = _count_nonmissing(matrix[selector], axis, zero_to_na)
    return np.asarray(per_item, dtype=float), int(g_mask.sum())


def _plot_completeness_partition(
    matrix,
    axis,
    axis_labels,
    zero_to_na,
    grouping_frame,
    group_by_partition,
    order,
    fraction_thresh,
    print_stats,
    xlabel_rotation,
    figsize,
    ax,
):
    """Plot boxplots of completeness partitioned by a grouping column.

    For each group in ``grouping_frame[group_by_partition]``, the
    matching rows (axis 0) or columns (axis 1) of *matrix* are selected.
    The fraction of non-missing values per item within that selection
    is then computed. The per-item fractions of every group are drawn as
    side-by-side boxplots.

    Parameters
    ----------
    matrix
        Data matrix (dense or scipy sparse) of shape
        ``(n_obs, n_vars)``.
    axis
        ``0`` partitions observations (items are variables); ``1``
        partitions variables (items are observations).
    axis_labels
        ``(item_label, other_label)`` used in the title and labels.
    zero_to_na
        Treat zeros as missing when True.
    grouping_frame
        ``.obs`` (axis 0) or ``.var`` (axis 1) frame holding the
        grouping column.
    group_by_partition
        Name of the grouping column.
    order
        Optional explicit ordering/subset of groups.
    fraction_thresh
        Optional value drawn as a horizontal dashed red line.
    print_stats
        Print global and per-group summary statistics when True.
    xlabel_rotation
        Rotation in degrees applied to the x tick labels.
    figsize
        Figure size used when a new figure is created.
    ax
        Existing Axes to draw on, or ``None`` to create a new figure.

    Returns
    -------
    tuple
        ``(fig, ax)``: the figure and the Axes that were drawn on.

    Raises
    ------
    KeyError
        If ``group_by_partition`` is not a column of ``grouping_frame``.
    ValueError
        If no groups are found, or ``order`` names an unknown group.
    """
    if group_by_partition not in grouping_frame.columns:
        raise KeyError(
            f"Column '{group_by_partition}' not found "
            f"in {'.obs' if axis == 0 else '.var'}",
        )

    group_series = grouping_frame[group_by_partition]
    available = list(group_series.dropna().unique())
    unique_groups = _resolve_partition_order(
        order, available,
    )

    if len(unique_groups) == 0:
        raise ValueError(
            "No groups found for the given "
            "`group_by_partition` column.",
        )

    # Boxplot categories are the string form of each group value.
    group_labels = [str(g) for g in unique_groups]

    # -- compute completeness per item within each group. Collect whole
    # arrays per group and build the long-format frame once, rather
    # than creating one Python dict per (item, group) pair.
    group_column = []
    fraction_chunks = []
    for g, label in zip(unique_groups, group_labels):
        g_mask = (group_series == g).values
        counts_g, g_size = _group_completeness_counts(
            matrix, axis, g_mask, zero_to_na,
        )
        fracs = counts_g / g_size
        fraction_chunks.append(fracs)
        group_column.extend([label] * fracs.size)

    long_df = pd.DataFrame({
        "Group": group_column,
        "Completeness": np.concatenate(fraction_chunks),
    })

    if print_stats:
        print("Global:")
        print(_summary_stats(
            long_df["Completeness"],
        ).to_string(
            index=False, float_format="%.4f",
        ))
        per_group = (
            long_df.groupby("Group")["Completeness"]
            .agg(["count", "mean", "median",
                  "std", "min", "max"])
            .reindex(group_labels)
        )
        print(f"\nPer {group_by_partition}:")
        print(per_group.to_string(
            float_format="%.4f",
        ))
        print()

    if ax is None:
        fig, _ax = plt.subplots(figsize=figsize)
    else:
        _ax = ax
        fig = _ax.get_figure()
    sns.boxplot(
        data=long_df,
        x="Group",
        y="Completeness",
        order=group_labels,
        ax=_ax,
    )
    _ax.set_title(
        f"Completeness per {axis_labels[0]} "
        f"by '{group_by_partition}'",
    )
    _ax.set_xlabel(group_by_partition)
    _ax.set_ylabel(
        f"Fraction of non-missing {axis_labels[1]} "
        f"values per {axis_labels[0]}",
    )
    if fraction_thresh is not None:
        _ax.axhline(
            fraction_thresh,
            color="red",
            linestyle="--",
            label=f"fraction_thresh={fraction_thresh}",
        )
        _ax.legend()
    plt.setp(
        _ax.get_xticklabels(),
        rotation=xlabel_rotation,
    )
    return fig, _ax


def _plot_completeness_ungrouped(
    matrix,
    axis,
    axis_labels,
    axis_length,
    zero_to_na,
    fraction_thresh,
    print_stats,
    bin_edges,
    xlabel_rotation,
    figsize,
    ax,
):
    """Plot a histogram of per-item completeness without grouping.

    Completeness of an item is its number of non-missing values along
    *axis* divided by *axis_length*. The fractions are drawn as a
    seaborn histogram with the given *bin_edges*. If *fraction_thresh*
    is set, it is marked with a vertical dashed red line.

    Returns
    -------
    tuple
        ``(fig, ax)``: the figure and the Axes that were drawn on.
    """
    item_label = axis_labels[0]
    other_label = axis_labels[1]

    nonmissing = _count_nonmissing(matrix, axis, zero_to_na)
    fractions = np.asarray(nonmissing, dtype=float) / axis_length

    if print_stats:
        print("Global:")
        summary = _summary_stats(fractions)
        print(summary.to_string(index=False, float_format="%.4f"))
        print()

    if ax is not None:
        target_ax = ax
        fig = target_ax.get_figure()
    else:
        fig, target_ax = plt.subplots(figsize=figsize)

    sns.histplot(fractions, bins=bin_edges, ax=target_ax)
    target_ax.set_title(f"Completeness per {item_label}")
    target_ax.set_xlabel(
        f"Fraction of non-missing {other_label} values "
        f"per {item_label}",
    )

    if fraction_thresh is not None:
        target_ax.axvline(
            fraction_thresh,
            color="red",
            linestyle="--",
            label=f"fraction_thresh={fraction_thresh}",
        )
        target_ax.legend()

    plt.setp(target_ax.get_xticklabels(), rotation=xlabel_rotation)
    return fig, target_ax


def _plot_completeness_resolution(
    matrix,
    axis,
    axis_labels,
    n_items,
    zero_to_na,
    grouping_frame,
    group_by_resolution,
    min_count,
    min_fraction,
    fraction_thresh,
    print_stats,
    bin_edges,
    xlabel_rotation,
    figsize,
    ax,
):
    """Plot a histogram of per-item detection fractions across groups.

    An item is "detected" in a group when its non-missing count within
    that group is at least *min_count*. If *min_fraction* is given, the
    test is instead that its non-missing fraction within the group is
    at least *min_fraction*. With neither given, ``min_count=1`` is
    used. The plotted value per item is the number of groups where it
    is detected, divided by the number of (non-NaN) groups.

    Returns
    -------
    tuple
        ``(fig, ax)``: the figure and the Axes that were drawn on.

    Raises
    ------
    KeyError
        If ``group_by_resolution`` is not a column of ``grouping_frame``.
    ValueError
        If the grouping column has no non-NaN values.
    """
    if group_by_resolution not in grouping_frame.columns:
        raise KeyError(
            f"Column '{group_by_resolution}' not found in "
            f"{'.obs' if axis == 0 else '.var'}",
        )

    labels = grouping_frame[group_by_resolution]
    groups = list(labels.dropna().unique())
    if not groups:
        raise ValueError(
            "No groups found for the given "
            "`group_by_resolution` column.",
        )
    n_groups = len(groups)

    use_fraction = min_fraction is not None
    if min_count is None and not use_fraction:
        # Default: one non-missing value in the group counts as detected.
        min_count = 1

    n_detected = np.zeros(n_items, dtype=int)
    for group in groups:
        in_group = (labels == group).values
        counts, size = _group_completeness_counts(
            matrix, axis, in_group, zero_to_na,
        )
        if use_fraction:
            hit = counts / size >= min_fraction
        else:
            hit = counts >= min_count
        n_detected += hit.astype(int)

    detection_fractions = n_detected / n_groups

    if print_stats:
        print("Global:")
        summary = _summary_stats(detection_fractions)
        print(summary.to_string(index=False, float_format="%.4f"))
        print()

    if ax is not None:
        target_ax = ax
        fig = target_ax.get_figure()
    else:
        fig, target_ax = plt.subplots(figsize=figsize)

    sns.histplot(detection_fractions, bins=bin_edges, ax=target_ax)

    threshold_label = (
        f"min_fraction={min_fraction}"
        if use_fraction
        else f"min_count={min_count}"
    )
    item_label = axis_labels[0]

    target_ax.set_title(
        f"'{group_by_resolution}' completeness "
        f"per {item_label}",
    )
    target_ax.set_xlabel(
        f"Fraction of '{group_by_resolution}' groups "
        f"where {item_label} is detected "
        f"({threshold_label})",
    )

    if fraction_thresh is not None:
        target_ax.axvline(
            fraction_thresh,
            color="red",
            linestyle="--",
            label=f"fraction_thresh={fraction_thresh}",
        )
        target_ax.legend()

    plt.setp(target_ax.get_xticklabels(), rotation=xlabel_rotation)
    return fig, target_ax


def completeness(
    adata: ad.AnnData,
    axis: int,
    layer: str | None = None,
    zero_to_na: bool = False,
    order: Sequence[Any] | None = None,
    group_by_partition: str | None = None,
    group_by_resolution: str | None = None,
    min_count: int | None = None,
    min_fraction: float | None = None,
    fraction_thresh: float | None = None,
    print_stats: bool = False,
    bin_width: float = 0.01,
    xlabel_rotation: float = 0.0,
    figsize: tuple[float, float] = (6.0, 5.0),
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot completeness across observations or variables.

    Without grouping, a histogram shows the per-item fraction of
    non-missing values. With ``group_by_resolution``, the histogram
    instead shows the fraction of groups in which each item is
    "detected": it has at least ``min_count`` non-missing values, or at
    least ``min_fraction`` non-missing values, within the group. With
    ``group_by_partition``, per-item completeness is computed inside
    each group and drawn as side-by-side boxplots.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    axis
        ``0`` plots completeness per variable, ``1`` per observation.
    layer
        Name of the layer to use instead of ``.X``.
    zero_to_na
        Treat zero entries as missing values when True.
    order
        Explicit ordering and subsetting of groups when
        ``group_by_partition`` is provided. Groups not listed
        are excluded.
    group_by_partition
        Column in ``.obs`` (axis 0) or ``.var`` (axis 1) used to
        partition the opposite axis into groups for boxplots.
        Mutually exclusive with ``group_by_resolution``.
    group_by_resolution
        Column in ``.obs`` (axis 0) or ``.var`` (axis 1) defining
        detection groups.
    min_count : int or None, optional
        Minimum non-missing count within a group for detection.
        Mutually exclusive with ``min_fraction``; only used with
        ``group_by_resolution``.
    min_fraction : float or None, optional
        Minimum non-missing fraction within a group for detection.
        Mutually exclusive with ``min_count``; only used with
        ``group_by_resolution``.
    fraction_thresh : float or None, optional
        Threshold in ``[0, 1]`` drawn as a dashed line (vertical on
        histograms, horizontal on boxplots).
    print_stats : bool, optional
        Print completeness summary statistics before plotting. With
        ``group_by_partition``, per-group statistics follow the global
        summary.
    bin_width : float, optional
        Histogram bin width on the fraction axis. Bins span from 0.0
        to 1.0 + ``bin_width``. Defaults to 0.01.
    xlabel_rotation
        Rotation angle in degrees applied to x-axis tick labels.
    figsize
        ``(width, height)`` of a newly created figure, in inches.
    show
        Call ``plt.show()`` after plotting when True.
    ax : Axes or None, optional
        Axes to draw on. If ``None``, a new figure and axes are created.
    save : str or Path or None, optional
        Path to save the figure to (300 dpi, tight bounding box). If
        ``None``, nothing is saved.

    Returns
    -------
    Axes
        The Matplotlib Axes object used for plotting.
    """
    (
        matrix,
        axis_labels,
        n_items,
        axis_length,
        grouping_frame,
        min_count,
        min_fraction,
    ) = _validate_completeness_args(
        adata, axis, layer, order,
        group_by_resolution, group_by_partition,
        min_count, min_fraction, fraction_thresh,
        bin_width,
    )

    # Edges step by bin_width and extend one bin past 1.0.
    bin_edges = np.arange(0.0, 1.0 + 2 * bin_width, bin_width)

    # Keyword arguments every plotting helper accepts.
    shared = {
        "matrix": matrix,
        "axis": axis,
        "axis_labels": axis_labels,
        "zero_to_na": zero_to_na,
        "fraction_thresh": fraction_thresh,
        "print_stats": print_stats,
        "xlabel_rotation": xlabel_rotation,
        "figsize": figsize,
        "ax": ax,
    }

    if group_by_partition is not None:
        fig, plot_ax = _plot_completeness_partition(
            grouping_frame=grouping_frame,
            group_by_partition=group_by_partition,
            order=order,
            **shared,
        )
    elif group_by_resolution is not None:
        fig, plot_ax = _plot_completeness_resolution(
            n_items=n_items,
            grouping_frame=grouping_frame,
            group_by_resolution=group_by_resolution,
            min_count=min_count,
            min_fraction=min_fraction,
            bin_edges=bin_edges,
            **shared,
        )
    else:
        fig, plot_ax = _plot_completeness_ungrouped(
            axis_length=axis_length,
            bin_edges=bin_edges,
            **shared,
        )

    if save is not None:
        fig.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()

    return plot_ax


def completeness_per_var(
    adata: ad.AnnData,
    layer: str | None = None,
    zero_to_na: bool = False,
    order: Sequence[Any] | None = None,
    group_by_partition: str | None = None,
    group_by_resolution: str | None = None,
    min_count: int | None = None,
    min_fraction: float | None = None,
    fraction_thresh: float | None = None,
    print_stats: bool = False,
    bin_width: float = 0.01,
    xlabel_rotation: float = 0.0,
    figsize: tuple[float, float] = (6.0, 5.0),
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot completeness per variable.

    A variable's (column's) completeness is the fraction of
    observations (rows) holding a non-missing value for it. By default
    these fractions are shown as a histogram. With
    ``group_by_resolution``, the histogram shows the fraction of
    observation groups in which each variable is detected. With
    ``group_by_partition``, per-variable completeness is computed
    within each observation group and shown as boxplots.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    layer
        Name of the layer to use instead of ``.X``.
    zero_to_na
        Treat zero entries as missing values when True.
    order
        Explicit ordering and subsetting of groups when
        ``group_by_partition`` is provided. Groups not listed are
        excluded.
    group_by_partition
        Column in ``.obs`` used to partition observations into groups,
        drawn as side-by-side boxplots. Mutually exclusive with
        ``group_by_resolution``.
    group_by_resolution
        Column in ``.obs`` defining detection groups; the plot shows
        the fraction of groups in which each variable is detected.
    min_count : int or None, optional
        Minimum number of non-missing observations within a group for
        a variable to count as detected. Mutually exclusive with
        ``min_fraction``; only used with ``group_by_resolution``.
    min_fraction : float or None, optional
        Minimum fraction of non-missing observations within a group
        for a variable to count as detected. Mutually exclusive with
        ``min_count``; only used with ``group_by_resolution``.
    fraction_thresh : float or None, optional
        Threshold in ``[0, 1]`` drawn as a dashed line (vertical on
        histograms, horizontal on boxplots).
    print_stats : bool, optional
        Print completeness summary statistics before plotting. With
        ``group_by_partition``, per-group statistics follow the global
        summary.
    bin_width : float, optional
        Histogram bin width on the fraction axis. Bins span from 0.0
        to 1.0 + ``bin_width``. Defaults to 0.01.
    xlabel_rotation
        Rotation angle in degrees applied to x-axis tick labels.
    figsize
        ``(width, height)`` of a newly created figure, in inches.
    show
        Call ``plt.show()`` after plotting when True.
    ax : Axes or None, optional
        Axes to draw on. If ``None``, a new figure and axes are created.
    save : str or Path or None, optional
        Path to save the figure to. If ``None``, nothing is saved.

    Returns
    -------
    Axes
        The Matplotlib Axes object used for plotting.

    Examples
    --------
    >>> import proteopy as pr
    >>> adata = pr.datasets.example_peptide_data()
    >>> pr.pl.completeness_per_var(adata, fraction_thresh=0.7)
    >>> pr.pl.completeness_per_var(
    ...     adata,
    ...     group_by_resolution="condition",
    ...     min_count=1,
    ... )
    >>> pr.pl.completeness_per_var(
    ...     adata,
    ...     group_by_partition="condition",
    ...     order=["control", "treatment"],
    ... )
    """
    # Variables are the items: count along axis 0 (over observations).
    return completeness(
        adata,
        axis=0,
        layer=layer,
        zero_to_na=zero_to_na,
        order=order,
        group_by_partition=group_by_partition,
        group_by_resolution=group_by_resolution,
        min_count=min_count,
        min_fraction=min_fraction,
        fraction_thresh=fraction_thresh,
        print_stats=print_stats,
        bin_width=bin_width,
        xlabel_rotation=xlabel_rotation,
        figsize=figsize,
        show=show,
        ax=ax,
        save=save,
    )
def completeness_per_sample(
    adata: ad.AnnData,
    layer: str | None = None,
    zero_to_na: bool = False,
    order: Sequence[Any] | None = None,
    group_by_partition: str | None = None,
    group_by_resolution: str | None = None,
    min_count: int | None = None,
    min_fraction: float | None = None,
    fraction_thresh: float | None = None,
    print_stats: bool = False,
    bin_width: float = 0.01,
    xlabel_rotation: float = 0.0,
    figsize: tuple[float, float] = (6.0, 5.0),
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot completeness per sample (observation).

    A sample's (row's) completeness is the fraction of variables
    (columns) holding a non-missing value for it. By default these
    fractions are shown as a histogram. With ``group_by_resolution``,
    the histogram shows the fraction of variable groups in which each
    sample is detected. With ``group_by_partition``, per-sample
    completeness is computed within each variable group and shown as
    boxplots.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    layer
        Name of the layer to use instead of ``.X``.
    zero_to_na
        Treat zero entries as missing values when True.
    order
        Explicit ordering and subsetting of groups when
        ``group_by_partition`` is provided. Groups not listed are
        excluded.
    group_by_partition
        Column in ``.var`` used to partition variables into groups,
        drawn as side-by-side boxplots. Mutually exclusive with
        ``group_by_resolution``.
    group_by_resolution
        Column in ``.var`` defining detection groups; the plot shows
        the fraction of groups in which each sample is detected.
    min_count : int or None, optional
        Minimum number of non-missing variables within a group for a
        sample to count as detected. Mutually exclusive with
        ``min_fraction``; only used with ``group_by_resolution``.
    min_fraction : float or None, optional
        Minimum fraction of non-missing variables within a group for a
        sample to count as detected. Mutually exclusive with
        ``min_count``; only used with ``group_by_resolution``.
    fraction_thresh : float or None, optional
        Threshold in ``[0, 1]`` drawn as a dashed line (vertical on
        histograms, horizontal on boxplots).
    print_stats : bool, optional
        Print completeness summary statistics before plotting. With
        ``group_by_partition``, per-group statistics follow the global
        summary.
    bin_width : float, optional
        Histogram bin width on the fraction axis. Bins span from 0.0
        to 1.0 + ``bin_width``. Defaults to 0.01.
    xlabel_rotation
        Rotation angle in degrees applied to x-axis tick labels.
    figsize
        ``(width, height)`` of a newly created figure, in inches.
    show
        Call ``plt.show()`` after plotting when True.
    ax : Axes or None, optional
        Axes to draw on. If ``None``, a new figure and axes are created.
    save : str or Path or None, optional
        Path to save the figure to. If ``None``, nothing is saved.

    Returns
    -------
    Axes
        The Matplotlib Axes object used for plotting.

    Examples
    --------
    >>> import proteopy as pr
    >>> adata = pr.datasets.example_peptide_data()
    >>> pr.pl.completeness_per_sample(adata, fraction_thresh=0.5)

    With peptide-level proteodata, grouping by ``protein_id`` yields the
    fraction of proteins detected per sample.

    >>> pr.pl.completeness_per_sample(
    ...     adata,
    ...     group_by_resolution="protein_id",
    ...     min_count=1,
    ... )
    """
    # Samples are the items: count along axis 1 (over variables).
    return completeness(
        adata,
        axis=1,
        layer=layer,
        zero_to_na=zero_to_na,
        order=order,
        group_by_partition=group_by_partition,
        group_by_resolution=group_by_resolution,
        min_count=min_count,
        min_fraction=min_fraction,
        fraction_thresh=fraction_thresh,
        print_stats=print_stats,
        bin_width=bin_width,
        xlabel_rotation=xlabel_rotation,
        figsize=figsize,
        show=show,
        ax=ax,
        save=save,
    )
def _contains_value(seq, value) -> bool:
    """Return True if *value* is in *seq*, treating missing values as equal.

    Plain ``in`` cannot be used because ``nan != nan``. Any two missing
    values (as detected by :func:`pandas.isna`) compare equal here.
    """
    value_is_na = pd.isna(value)
    for item in seq:
        item_is_na = pd.isna(item)
        if item_is_na or value_is_na:
            # Never fall through to ``==`` with a missing operand:
            # ``pd.NA == x`` yields ``pd.NA``, whose truth value raises.
            if item_is_na and value_is_na:
                return True
            continue
        if item == value:
            return True
    return False


def _append_unique(seq, value) -> None:
    """Append *value* to *seq* only if not already present (NaN-aware)."""
    if not _contains_value(seq, value):
        seq.append(value)


def _n_var_summary_stats(series):
    """Return a one-row DataFrame of count summary statistics."""
    return pd.DataFrame({
        "mean_count": [series.mean()],
        "std_count": [series.std()],
        "median_count": [series.median()],
        "min_count": [series.min()],
        "max_count": [series.max()],
    })


def _add_pct_cols(df, total):
    """Add ``<stat>_pct = <stat>_count / total * 100`` columns in place."""
    for col in [
        "mean",
        "std",
        "median",
        "min",
        "max",
    ]:
        df[f"{col}_pct"] = (
            df[f"{col}_count"] / total * 100
        )


def _print_stats_df(df):
    """Print a DataFrame without its index, with one-decimal formatting."""
    print(df.to_string(
        index=False,
        float_format="%.1f",
    ))


# Named aggregations for per-group statistics; the output column names
# match those produced by ``_n_var_summary_stats``.
_AGG_STATS = {
    "mean_count": "mean",
    "std_count": "std",
    "median_count": "median",
    "min_count": "min",
    "max_count": "max",
}

# Bar colour used for observations whose group label has no colour
# assigned (e.g. a missing ``order_by`` value).
_UNMAPPED_BAR_COLOR = "lightgrey"


def _validate_n_var_per_sample_args(  # noqa: C901
    adata,
    level,
    group_by,
    order_by,
    order,
    layer,
):
    """Validate inputs for :func:`n_var_per_sample`.

    Returns
    -------
    tuple
        ``(data_level, level, matrix)``: the data level reported by
        ``is_proteodata``, the requested ``level`` (unchanged) and the
        selected matrix (``.X`` or ``.layers[layer]``).

    Raises
    ------
    ValueError
        Invalid ``level``, peptide counts requested from protein-level
        data, ``group_by`` combined with ``order_by``, an empty matrix, or
        unknown values in ``order``.
    KeyError
        Unknown ``layer``, ``group_by`` column or ``order_by`` column.
    """
    _, data_level = is_proteodata(adata, raise_error=True)

    # -- Validate level
    valid_levels = {"peptide", "protein", None}
    if level not in valid_levels:
        raise ValueError(
            f"Invalid level '{level}'. Must be "
            "'peptide', 'protein', or None."
        )
    if level == "peptide" and data_level == "protein":
        raise ValueError(
            "Cannot count peptides from "
            "protein-level data."
        )

    # -- Mutual exclusivity
    if group_by is not None and order_by is not None:
        raise ValueError(
            "`group_by` and `order_by` cannot be "
            "used together."
        )

    # -- Validate layer
    if layer is None:
        matrix = adata.X
    else:
        if layer not in adata.layers:
            raise KeyError(
                f"Layer '{layer}' not found in "
                "adata.layers."
            )
        matrix = adata.layers[layer]
    if matrix is None:
        raise ValueError(
            "Selected layer is empty; cannot "
            "compute variable counts."
        )

    # -- Validate group_by column
    if group_by is not None:
        if group_by not in adata.obs.columns:
            raise KeyError(
                f"Column '{group_by}' not found "
                "in adata.obs."
            )

    # -- Validate order_by column
    if order_by is not None:
        if order_by not in adata.obs.columns:
            raise KeyError(
                f"Column '{order_by}' not found "
                "in adata.obs."
            )

    # -- Validate order elements
    if order is not None:
        if group_by is not None:
            valid = set(
                adata.obs[group_by].dropna().unique()
            )
            source = f"adata.obs['{group_by}']"
        elif order_by is not None:
            valid = set(
                adata.obs[order_by].dropna().unique()
            )
            source = f"adata.obs['{order_by}']"
        else:
            valid = set(adata.obs_names)
            source = "adata.obs_names"
        invalid = [
            o for o in order if o not in valid
        ]
        if invalid:
            invalid_str = ", ".join(
                map(str, invalid)
            )
            raise ValueError(
                f"Unknown value(s) in `order`: "
                f"{invalid_str}. Valid values "
                f"come from {source}."
            )

    return data_level, level, matrix


def _valid_mask(matrix, zero_to_na):
    """Return a dense boolean mask of valid (non-missing) entries.

    Sparse input is densified, so its implicit zeros count as valid unless
    ``zero_to_na`` is True.
    """
    if sparse.issparse(matrix):
        arr = matrix.toarray()
    else:
        arr = np.asarray(matrix)
    mask = ~np.isnan(arr)
    if zero_to_na:
        mask &= arr != 0
    return mask


def _n_var_count_per_sample(
    matrix,
    zero_to_na,
    level,
    data_level,
    adata,
):
    """Count non-missing vars per sample (one value per row of *matrix*).

    When *level* is ``"protein"`` on peptide-level data, counts unique
    proteins with at least one non-missing peptide. Peptides without a
    ``protein_id`` do not contribute to any protein.
    """
    valid = _valid_mask(matrix, zero_to_na)

    # -- Count at native level
    if level is None or level == data_level:
        return valid.sum(axis=1)

    # -- Protein count from peptide data
    if level == "protein" and data_level == "peptide":
        protein_codes, protein_uniques = pd.factorize(
            adata.var["protein_id"].to_numpy(),
            sort=False,
        )
        n_proteins = len(protein_uniques)
        # pd.factorize encodes missing IDs as -1; used as a column index,
        # -1 would silently credit those peptides to the last protein.
        has_protein = protein_codes >= 0

        # OR-reduce peptide columns into protein columns
        prot_detected = np.zeros(
            (valid.shape[0], n_proteins),
            dtype=bool,
        )
        np.maximum.at(
            prot_detected,
            (slice(None), protein_codes[has_protein]),
            valid[:, has_protein],
        )
        return prot_detected.sum(axis=1)

    raise ValueError(
        f"Requested level '{level}' is "
        f"incompatible with "
        f"'{data_level}' data."
    )


def _n_var_derive_totals(
    counts_array,
    level,
    data_level,
    percentage,
    ylabel,
    title,
    adata,
):
    """Derive the total variable count and resolve display settings.

    Returns ``(total_vars, counts_array, ylabel, title)``; ``counts_array``
    is converted to percentages of ``total_vars`` when ``percentage`` is
    True, and ``ylabel``/``title`` get defaults when ``None``.
    """
    if level == "protein" and data_level == "peptide":
        total_vars = adata.var["protein_id"].nunique()
    else:
        total_vars = adata.n_vars

    if percentage:
        if total_vars == 0:
            raise ValueError(
                "Cannot compute percentage: "
                "no variables found."
            )
        counts_array = (
            counts_array / total_vars
        ) * 100

    # -- Resolve y-axis label
    if ylabel is None:
        ylabel = "%" if percentage else "#"

    # -- Resolve title
    if title is None:
        if level == "protein" or (
            level is None and data_level == "protein"
        ):
            entity = "proteins"
        elif level == "peptide" or (
            level is None and data_level == "peptide"
        ):
            entity = "peptides"
        else:
            entity = "variables"
        title = f"Number of detected {entity}"

    return total_vars, counts_array, ylabel, title


def _n_var_print_group_stats(
    counts,
    stats_df,
    group_by,
    total_vars,
):
    """Print global and per-group statistics."""
    global_df = _n_var_summary_stats(counts["count"])
    _add_pct_cols(global_df, total_vars)
    print("Global:")
    _print_stats_df(global_df)

    print_df = stats_df.copy()
    _add_pct_cols(print_df, total_vars)
    print(f"\nPer {group_by}:")
    _print_stats_df(print_df)


def _n_var_resolve_bar_colors(
    color_scheme,
    group_order,
    stats_df,
    group_by,
):
    """Resolve one bar colour per row of *stats_df* from a colour scheme.

    Colours are assigned over the full ``group_order`` so that a group
    keeps its colour even when other groups are absent.
    """
    if color_scheme is None:
        return None
    colors = _resolve_color_scheme(
        color_scheme,
        group_order,
    )
    if colors is None:
        return None
    return [
        colors[group_order.index(grp)]
        for grp in stats_df[group_by]
    ]


def _n_var_group_by_path(
    counts,
    adata,
    group_by,
    order,
    color_scheme,
    total_vars,
    ylabel,
    title,
    print_stats,
    figsize,
    xlabel_rotation,
    save,
    show,
    ax=None,
):
    """Plot a mean +/- std bar chart grouped by an obs column.

    Observations with a missing ``group_by`` label are dropped.
    """
    group_df = adata.obs[[group_by]].copy()
    group_df = group_df.rename_axis(
        "obs",
    ).reset_index()
    counts = pd.merge(
        counts,
        group_df,
        on="obs",
        how="left",
    )
    counts = counts.dropna(subset=[group_by])
    if counts.empty:
        raise ValueError(
            "No observations remain after "
            "aligning `group_by` labels.",
        )

    group_values = counts[group_by]
    if isinstance(
        group_values.dtype,
        pd.CategoricalDtype,
    ):
        group_values = (
            group_values.cat
            .remove_unused_categories()
        )
        counts[group_by] = group_values

    available_groups: list[Any] = []
    for value in group_values:
        _append_unique(available_groups, value)

    group_order = _n_var_resolve_group_order(
        order,
        available_groups,
        group_values,
    )
    # Append any groups not yet in order
    for value in available_groups:
        _append_unique(group_order, value)

    # -- Compute per-group statistics
    stats_df = (
        counts.groupby(group_by, observed=True)[
            "count"
        ]
        .agg(**_AGG_STATS)
        .reindex(group_order)
    )
    # Drop groups without observations (e.g. unused categories)
    stats_df = stats_df.dropna(
        subset=["mean_count"],
    )
    # Single-observation groups have an undefined std; draw no error bar
    stats_df["std_count"] = (
        stats_df["std_count"].fillna(0.0)
    )
    stats_df = stats_df.reset_index()

    if print_stats:
        _n_var_print_group_stats(
            counts,
            stats_df,
            group_by,
            total_vars,
        )

    # -- Plot grouped bar chart
    bar_colors = _n_var_resolve_bar_colors(
        color_scheme,
        group_order,
        stats_df,
        group_by,
    )

    if ax is not None:
        _ax = ax
        fig = _ax.get_figure()
    else:
        fig, _ax = plt.subplots(figsize=figsize)

    bar_labels = stats_df[group_by].astype(str)
    _ax.bar(
        bar_labels,
        stats_df["mean_count"],
        yerr=stats_df["std_count"],
        color=bar_colors,
        capsize=4.0,
        edgecolor="black",
    )
    plt.setp(
        _ax.get_xticklabels(),
        rotation=xlabel_rotation,
        ha="right",
    )
    _ax.set_xlabel(group_by)
    _ax.set_ylabel(ylabel)
    fig.suptitle(title, y=0.95)
    # Lay out the figure that owns the axes, which need not be the
    # "current" pyplot figure when the caller passed ``ax``.
    fig.tight_layout()

    if save is not None:
        fig.savefig(
            save,
            dpi=300,
            bbox_inches="tight",
        )
    if show:
        plt.show()
    return _ax


def _n_var_resolve_group_order(
    order,
    available_groups,
    group_values,
):
    """Resolve group ordering from *order* or the categorical categories.

    Precedence: a non-empty *order* (deduplicated), then the categories of
    a categorical *group_values*, then *available_groups* (a copy).
    """
    if order:
        # Deduplicate while preserving order
        group_order: list[Any] = []
        for grp in order:
            if not _contains_value(
                group_order,
                grp,
            ):
                group_order.append(grp)
        return group_order
    if isinstance(
        group_values.dtype,
        pd.CategoricalDtype,
    ):
        return list(
            group_values.cat.categories,
        )
    return available_groups.copy()


def _n_var_resolve_obs_ordering(
    counts,
    obs_df,
    group_key,
    order,
    available_groups,
    ascending,
):
    """Resolve observation ordering for the per-obs bar path.

    Returns ``(x_ordered, cat_index_map)``: every observation of *counts*
    in plotting order, and a mapping from group label (as ``str``) to the
    contiguous observations drawn under that label.
    """
    has_grouping = group_key != "_group"

    if has_grouping:
        group_order = _n_var_resolve_group_order(
            order,
            available_groups,
            obs_df[group_key],
        )
        for grp in available_groups:
            _append_unique(group_order, grp)

        cat_index_map: dict[str, list[str]] = {}
        for grp in group_order:
            # A missing label matches no observation via ``==``.
            if pd.isna(grp):
                continue
            obs_list = obs_df.loc[
                obs_df[group_key] == grp, "obs"
            ].tolist()
            if obs_list:
                cat_index_map[str(grp)] = obs_list
        x_ordered = [
            obs
            for obs_list in cat_index_map.values()
            for obs in obs_list
        ]
        # Observations without a group label are kept, unlabelled, at the
        # end; dropping them here would turn their x labels into NaN.
        x_ordered = list(
            dict.fromkeys([*x_ordered, *counts["obs"]])
        )
    else:
        if order:
            # Listed observations first, then the remaining ones.
            x_ordered = list(
                dict.fromkeys([*order, *counts["obs"]])
            )
        elif ascending is not None:
            sorted_counts = counts.sort_values(
                "count",
                ascending=ascending,
                kind="mergesort",
            )
            x_ordered = sorted_counts[
                "obs"
            ].tolist()
        else:
            x_ordered = counts[
                "obs"
            ].tolist()
        cat_index_map = {"all": x_ordered}

    return x_ordered, cat_index_map


def _n_var_plot_per_obs(
    counts,
    x_ordered,
    cat_index_map,
    group_key,
    order_by,
    total_vars,
    color_scheme,
    ylabel,
    title,
    print_stats,
    figsize,
    xlabel_rotation,
    order_by_label_rotation,
    save,
    show,
    ax=None,
):
    """Plot one bar per observation, with group labels above the bars.

    *counts* must already be sorted in *x_ordered* order; its
    ``group_key`` column is converted to ``str`` in place.
    """
    has_grouping = group_key != "_group"

    # -- Print statistics
    if print_stats:
        if has_grouping:
            global_df = _n_var_summary_stats(
                counts["count"],
            )
            _add_pct_cols(global_df, total_vars)
            print("Global:")
            _print_stats_df(global_df)

            print_df = (
                counts.groupby(
                    order_by,
                    observed=True,
                )["count"]
                .agg(**_AGG_STATS)
                .reset_index()
            )
            _add_pct_cols(print_df, total_vars)
            print(f"\nPer {order_by}:")
            _print_stats_df(print_df)
        else:
            print_df = _n_var_summary_stats(
                counts["count"],
            )
            _add_pct_cols(print_df, total_vars)
            _print_stats_df(print_df)

    # -- Resolve colors
    counts[group_key] = (
        counts[group_key].astype(str)
    )
    unique_groups = list(cat_index_map.keys())
    colors = _resolve_color_scheme(
        color_scheme,
        unique_groups,
    )
    plot_kwargs = {}
    if colors is not None:
        color_map = {
            str(grp): colors[i]
            for i, grp in enumerate(unique_groups)
        }
        # Observations outside every labelled group (missing label) get a
        # neutral colour instead of NaN, which matplotlib rejects.
        plot_kwargs["color"] = [
            color_map.get(label, _UNMAPPED_BAR_COLOR)
            for label in counts[group_key]
        ]

    # -- Plot per-observation bars
    if ax is not None:
        _ax = ax
        fig = _ax.get_figure()
    else:
        fig, _ax = plt.subplots(figsize=figsize)

    counts.plot(
        kind="bar",
        x="obs",
        y="count",
        ax=_ax,
        legend=False,
        **plot_kwargs,
    )
    plt.setp(
        _ax.get_xticklabels(),
        rotation=xlabel_rotation,
        ha="right",
    )
    _ax.set_xlabel("")
    _ax.set_ylabel(ylabel)

    # -- Add group labels above bars
    obs_idx_map = {
        obs: i for i, obs in enumerate(x_ordered)
    }
    ymax = counts['count'].max()
    for cat, obs_list in cat_index_map.items():
        if not obs_list:
            continue
        start_idx = obs_idx_map[obs_list[0]]
        end_idx = obs_idx_map[obs_list[-1]]
        mid_idx = (start_idx + end_idx) / 2
        _ax.text(
            x=mid_idx,
            y=ymax * 1.05,
            s=cat,
            ha='center',
            va='bottom',
            fontsize=8,
            fontweight='bold',
            rotation=order_by_label_rotation,
        )

    fig.suptitle(title, y=0.95)
    # Lay out the figure that owns the axes, which need not be the
    # "current" pyplot figure when the caller passed ``ax``.
    fig.tight_layout()

    if save is not None:
        fig.savefig(
            save,
            dpi=300,
            bbox_inches='tight',
        )
    if show:
        plt.show()
    return _ax


def n_var_per_sample(
    adata: ad.AnnData,
    *,
    layer: str | None = None,
    zero_to_na: bool = False,
    level: str | None = None,
    percentage: bool = False,
    ascending: bool | None = None,
    order_by: str | None = None,
    order: Sequence[str] | None = None,
    group_by: str | None = None,
    print_stats: bool = False,
    figsize: tuple[float, float] = (6.0, 4.0),
    color_scheme: str | dict | Sequence | Colormap | callable | None = None,
    title: str | None = None,
    ylabel: str | None = None,
    xlabel_rotation: float = 90,
    order_by_label_rotation: float = 0,
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot the number of detected variables (peptides or proteins) per sample.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    layer : str or None, optional
        Key in ``adata.layers``; when set, uses that layer instead of
        ``.X``.
    zero_to_na : bool, optional
        If ``True``, zeros in the matrix are treated as missing values.
    level : str or None, optional
        ``"peptide"`` counts detected peptides. ``"protein"`` counts
        detected proteins; on peptide-level data a protein is detected
        when at least one of its peptides is non-missing. ``None``
        follows the intrinsic level of the data (``.var``).
    percentage : bool, optional
        Display y-axis values as a percentage of total variables instead
        of raw counts.
    ascending : bool or None, optional
        Sort observations by detected counts. ``True`` places lower
        counts to the left; ``False`` places higher counts to the left;
        ``None`` preserves the existing observation order. Ignored when
        ``group_by`` or ``order`` is set.
    order_by : str or None, optional
        Column in ``adata.obs`` used for grouping and colouring bars.
        Observations with a missing label are drawn last, without a group
        label.
    order : Sequence[str] or None, optional
        Controls ordering on the x-axis. Without ``group_by`` or
        ``order_by`` it lists observation names. With ``order_by`` it
        specifies the group order. With ``group_by`` it specifies the
        group order for the bar chart. Values that are not listed are
        appended after the listed ones; nothing is excluded.
    group_by : str or None, optional
        Column in ``adata.obs`` used to summarise observations into
        groups. When provided, a mean +/- std bar chart is shown and
        observations with a missing label are dropped. Mutually exclusive
        with ``order_by``.
    print_stats : bool, optional
        Print summary statistics as a DataFrame.
    figsize : tuple of float, optional
        Figure size ``(width, height)`` in inches passed to
        :func:`matplotlib.pyplot.subplots`.
    color_scheme
        Colour mapping for groups. Accepts a named Matplotlib colormap,
        a single colour, a list/tuple of colours, a dict mapping labels
        to colours, a :class:`~matplotlib.colors.Colormap`, or a
        callable.
    title : str or None, optional
        Plot title.
    ylabel : str or None, optional
        Label for the y-axis.
    xlabel_rotation : float, optional
        Rotation in degrees applied to x-axis tick labels.
    order_by_label_rotation : float, optional
        Rotation in degrees applied to group labels drawn above the plot.
    show : bool, optional
        Call :func:`matplotlib.pyplot.show` when ``True``.
    ax : Axes or None, optional
        Matplotlib Axes to plot onto. If ``None``, a new figure and axes
        are created.
    save : str or Path or None, optional
        File path to save the figure.

    Returns
    -------
    Axes
        The Matplotlib Axes object used for plotting.

    Examples
    --------
    >>> import proteopy as pr
    >>> adata = pr.datasets.karayel_2020()
    >>> pr.pl.n_var_per_sample(adata)

    Show mean +/- std per group:

    >>> pr.pl.n_var_per_sample(
    ...     adata,
    ...     group_by="cell_type",
    ... )

    Order bars by a grouping column:

    >>> pr.pl.n_var_per_sample(
    ...     adata,
    ...     order_by="cell_type",
    ...     order=["LBaso", "Ortho"],
    ... )
    """
    data_level, level, matrix = (
        _validate_n_var_per_sample_args(
            adata,
            level,
            group_by,
            order_by,
            order,
            layer,
        )
    )

    # -- Count non-missing vars per sample
    counts_array = _n_var_count_per_sample(
        matrix,
        zero_to_na,
        level,
        data_level,
        adata,
    )

    # -- Derive totals, percentage, ylabel, and title
    total_vars, counts_array, ylabel, title = (
        _n_var_derive_totals(
            counts_array,
            level,
            data_level,
            percentage,
            ylabel,
            title,
            adata,
        )
    )

    # -- Build counts DataFrame (columns: "obs", "count")
    counts_series = pd.Series(
        counts_array,
        index=adata.obs_names,
        name="count",
    )
    counts = counts_series.rename_axis(
        "obs",
    ).reset_index()

    # -- Warn when ascending has no effect
    if ascending is not None:
        if group_by is not None:
            warnings.warn(
                "`ascending` is ignored when "
                "`group_by` is set.",
                UserWarning,
                stacklevel=2,
            )
        elif order is not None:
            warnings.warn(
                "`ascending` is ignored when "
                "`order` is set explicitly.",
                UserWarning,
                stacklevel=2,
            )

    # -- group_by path: mean +/- std bar plot per group
    if group_by is not None:
        return _n_var_group_by_path(
            counts,
            adata,
            group_by,
            order,
            color_scheme,
            total_vars,
            ylabel,
            title,
            print_stats,
            figsize,
            xlabel_rotation,
            save,
            show,
            ax,
        )

    # -- Per-observation bar plot (with optional order_by)
    has_grouping = order_by is not None
    group_key = (
        order_by if has_grouping else "_group"
    )

    # Attach grouping column to counts
    if has_grouping:
        if group_key != "obs":
            obs = adata.obs[[group_key]].copy()
            obs = obs.rename_axis(
                "obs",
            ).reset_index()
            counts = pd.merge(
                counts,
                obs,
                on="obs",
                how="left",
            )
        else:
            counts[group_key] = counts["obs"]
    else:
        counts[group_key] = "all"

    obs_df = adata.obs.copy()
    obs_df = obs_df.rename_axis(
        "obs",
    ).reset_index()
    if group_key not in obs_df.columns:
        obs_df[group_key] = "all"

    available_groups: list[Any] = []
    for value in obs_df[group_key]:
        _append_unique(available_groups, value)

    # -- Resolve observation ordering
    x_ordered, cat_index_map = (
        _n_var_resolve_obs_ordering(
            counts,
            obs_df,
            group_key,
            order,
            available_groups,
            ascending,
        )
    )

    counts["obs"] = pd.Categorical(
        counts["obs"],
        categories=x_ordered,
        ordered=True,
    )
    counts = counts.sort_values("obs")

    # -- Plot per-observation bars
    return _n_var_plot_per_obs(
        counts,
        x_ordered,
        cat_index_map,
        group_key,
        order_by,
        total_vars,
        color_scheme,
        ylabel,
        title,
        print_stats,
        figsize,
        xlabel_rotation,
        order_by_label_rotation,
        save,
        show,
        ax,
    )


n_peptides_per_sample = partial_with_docsig(
    n_var_per_sample,
    level="peptide",
    docstr_header="""\
Plot the number of detected peptides per sample.

For each sample (observation), counts the number of peptides with
non-missing values. Requires peptide-level proteodata.""",
    docstr_examples="""\
>>> import proteopy as pr
>>> adata = pr.datasets.williams_2018()
>>> pr.pl.n_peptides_per_sample(adata)

Show mean +/- std per group:

>>> pr.pl.n_peptides_per_sample(
...     adata,
...     group_by="tissue",
... )""",
)

n_proteins_per_sample = partial_with_docsig(
    n_var_per_sample,
    level="protein",
    docstr_header="""\
Plot the number of detected proteins per sample.

For each sample (observation), counts the number of proteins with
non-missing values. On peptide-level data, a protein counts as detected
when at least one of its peptides is non-missing.""",
    docstr_examples="""\
>>> import proteopy as pr

Protein-level data:

>>> adata = pr.datasets.karayel_2020()
>>> pr.pl.n_proteins_per_sample(adata)

Peptide-level data (aggregated to proteins):

>>> adata = pr.datasets.williams_2018()
>>> pr.pl.n_proteins_per_sample(adata)

Show mean +/- std per group:

>>> pr.pl.n_proteins_per_sample(
...     adata,
...     group_by="tissue",
... )""",
)
def n_samples_per_category(
    adata: ad.AnnData,
    category_key: str | Sequence[str],
    categories: Sequence[Any] | None = None,
    ignore_na: bool = False,
    ascending: bool = False,
    order: Sequence[Any] | None = None,
    xlabel_rotation: float = 45.0,
    color_scheme: Any | None = None,
    figsize: tuple[float, float] = (6.0, 4.0),
    show: bool = True,
    save: str | Path | None = None,
    ax: bool = False,
) -> Axes | None:
    """
    Plot sample (obs) counts per category (optionally stratified).

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix with categorical obs annotations.
    category_key : str | Sequence[str]
        One or two column names in ``adata.obs`` used to stratify
        observations. With two columns, counts of the second column are
        stacked inside each bar of the first. Columns that are neither
        string nor categorical are converted to strings; labels passed via
        ``categories``/``order`` are then matched on their string form.
    categories : Sequence[Any] | None
        Labels from the first category column to display on the x-axis.
        Rows whose first-column value is not listed are dropped. Bars follow
        the order given here unless ``order`` overrides it. Duplicate labels
        are ignored.
    ignore_na : bool
        Drop observations with missing labels when ``True``; otherwise,
        missing values are shown as ``"missing"``.
    ascending : bool
        Sort categories by total counts when neither ``order`` nor
        ``categories`` is supplied. ``True`` places lower counts on the
        left.
    order : Sequence[Any] | None
        Explicit order for the x-axis labels (values of the first category
        column). Any levels not listed are appended afterwards in their
        intrinsic order. When provided, ``ascending`` is ignored. Duplicate
        labels are ignored.
    xlabel_rotation : float
        Rotation angle (degrees) applied to the x-axis tick labels.
    color_scheme : Any | None
        Mapping, sequence, colormap name, or callable used to colour
        categories.
    figsize : tuple[float, float]
        Figure size (width, height) in inches used for
        :func:`matplotlib.pyplot.subplots`.
    show : bool
        Call :func:`matplotlib.pyplot.show` when ``True``.
    save : str | Path | None
        Save the figure to the provided path (``str`` or
        :class:`~pathlib.Path`).
    ax : bool
        Return the :class:`~matplotlib.axes.Axes` instead of closing the
        figure.

    Returns
    -------
    Axes or None
        The Axes when ``ax=True``; otherwise ``None`` (the figure is closed
        after being shown and/or saved).

    Raises
    ------
    KeyError
        If a column of ``category_key`` is missing from ``adata.obs``.
    ValueError
        If ``category_key`` or ``categories`` is empty, no observation
        matches ``categories``, no observation remains after NA handling,
        or ``order`` contains unknown labels.
    NotImplementedError
        If more than two category columns are given.
    """
    check_proteodata(adata)

    if isinstance(category_key, str):
        category_cols = [category_key]
    else:
        category_cols = list(category_key)
    if not category_cols:
        raise ValueError("category_key must contain at least one column name.")

    missing_label = "missing"
    unknown_cols = [col for col in category_cols if col not in adata.obs]
    if unknown_cols:
        raise KeyError(
            "Column(s) missing in adata.obs: "
            f"{', '.join(map(str, unknown_cols))}."
        )
    # Checked before any figure is created so no figure is leaked.
    if len(category_cols) > 2:
        raise NotImplementedError(
            "Plotting more than two category columns is not implemented."
        )

    obs = adata.obs.loc[:, category_cols].copy()
    stringified_cols: set[Any] = set()
    for col in category_cols:
        is_categorical = isinstance(obs[col].dtype, pd.CategoricalDtype)
        if not (is_categorical or is_string_dtype(obs[col])):
            obs[col] = obs[col].astype("string")
            stringified_cols.add(col)
        if ignore_na:
            continue
        if is_categorical and missing_label not in obs[col].cat.categories:
            obs[col] = obs[col].cat.add_categories([missing_label])
        obs[col] = obs[col].fillna(missing_label)

    first_cat_col = category_cols[0]
    if ignore_na:
        obs = obs.dropna(subset=category_cols, how="any")

    def _normalize_labels(labels) -> list[Any]:
        """Turn user labels into a de-duplicated list matching the data."""
        if isinstance(labels, (str, bytes)):
            labels = [labels]
        labels = list(labels)
        if first_cat_col in stringified_cols:
            # The column now holds strings; numeric labels must match too.
            labels = [str(label) for label in labels]
        return list(dict.fromkeys(labels))

    selected_categories: list[Any] | None = None
    if categories is not None:
        selected_categories = _normalize_labels(categories)
        if not selected_categories:
            raise ValueError("categories must contain at least one label.")
        mask = obs[first_cat_col].isin(selected_categories)
        if not mask.any():
            raise ValueError("No observations match the requested categories.")
        obs = obs.loc[mask].copy()

    if obs.empty:
        raise ValueError("No observations available after NA handling.")

    for col in category_cols:
        if isinstance(obs[col].dtype, pd.CategoricalDtype):
            obs[col] = obs[col].cat.remove_unused_categories()

    def _ordered_categories(series: pd.Series) -> list[Any]:
        """Intrinsic level order, with the missing label moved last."""
        if isinstance(series.dtype, pd.CategoricalDtype):
            ordered = list(series.cat.categories)
        else:
            ordered = list(pd.unique(series))
        if not ignore_na and missing_label in ordered:
            ordered = [
                value for value in ordered if value != missing_label
            ] + [missing_label]
        return ordered

    first_level_order = _ordered_categories(obs[first_cat_col])
    if selected_categories is not None:
        first_level_order = [
            category
            for category in selected_categories
            if category in first_level_order
        ]

    if order is not None:
        specified = _normalize_labels(order)
        unknown_specified = [
            cat for cat in specified if cat not in first_level_order
        ]
        if unknown_specified:
            raise ValueError(
                "Order values not present in the first category column: "
                f"{', '.join(map(str, unknown_specified))}."
            )
        remaining = [cat for cat in first_level_order if cat not in specified]
        first_level_order = specified + remaining

    use_count_sort = order is None and selected_categories is None

    fig, _ax = plt.subplots(figsize=figsize)

    if len(category_cols) == 1:
        freq = obs[first_cat_col].value_counts(dropna=False)
        if use_count_sort:
            freq = freq.sort_values(ascending=ascending)
        else:
            freq = freq.reindex(first_level_order, fill_value=0)

        plot_kwargs: dict[str, Any] = {}
        if color_scheme is not None:
            colors = _resolve_color_scheme(color_scheme, freq.index)
            if colors is not None:
                plot_kwargs["color"] = colors
        freq.plot(kind="bar", ax=_ax, **plot_kwargs)
    else:
        second_cat_col = category_cols[1]
        second_level_order = _ordered_categories(obs[second_cat_col])
        df = (
            obs.groupby(category_cols, observed=False)
            .size()
            .unstack(fill_value=0)
        )
        df = df.reindex(first_level_order, fill_value=0)
        df = df.reindex(columns=second_level_order, fill_value=0)
        if use_count_sort:
            df = df.loc[df.sum(axis=1).sort_values(ascending=ascending).index]

        colors = _resolve_color_scheme(color_scheme, df.columns)
        plot_kwargs = {}
        if colors is not None:
            plot_kwargs["color"] = colors
        df.plot(kind="bar", stacked=True, ax=_ax, **plot_kwargs)
        if df.shape[1] > 1:
            _ax.legend(loc="center right", bbox_to_anchor=(1.4, 0.5))

    _ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    _ax.set_xlabel(first_cat_col)
    _ax.set_ylabel('#')
    ha = (
        'right' if xlabel_rotation > 0
        else 'left' if xlabel_rotation < 0
        else 'center'
    )
    plt.setp(_ax.get_xticklabels(), rotation=xlabel_rotation, ha=ha)
    fig.tight_layout()

    save_path: Path | None = Path(save) if save is not None else None
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    if ax:
        return _ax
    if not show and save_path is None:
        warnings.warn(
            "Function does not do anything. Enable `show`, provide a `save` path, "
            "or set `ax=True`.",
            stacklevel=2,
        )
    # The figure is not handed back to the caller, so release it.
    plt.close(fig)
    return None
def n_cat1_per_cat2_hist(
    adata: ad.AnnData,
    first_category: str,
    second_category: str,
    axis: int,
    bin_width: float | None = None,
    bin_range: tuple[float, float] | None = None,
    print_stats: bool = False,
    figsize: tuple[float, float] = (6.0, 4.0),
    show: bool = True,
    save: str | Path | None = None,
    ax: Axes | None = None,
) -> Axes:
    """
    Plot the distribution of the number of first-category entries per
    second category.

    Only second-category labels that actually occur are counted: unused
    levels of a categorical column do not contribute zero-count entries.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    first_category : str
        Column providing the secondary category from the same axis as
        ``second_category``. Pass ``"index"`` to use ``adata.obs_names``
        (``axis == 0``) or ``adata.var_names`` (``axis == 1``).
    second_category : str
        Column name identifying the primary category. Resolved from
        ``adata.obs`` when ``axis == 0`` and ``adata.var`` when
        ``axis == 1``. Passing ``"index"`` is not supported.
    axis : int
        ``0`` to work on ``adata.obs``, ``1`` to work on ``adata.var``.
    bin_width : float | None
        Optional histogram bin width. Must be positive when provided. When
        ``None``, numpy's ``"auto"`` bin width is used, but never less than
        1.
    bin_range : tuple[float, float] | None
        Optional ``(lower, upper)`` pair (tuple or list) limiting the
        histogram bins. ``lower`` must be strictly smaller than ``upper``.
    print_stats : bool
        Print distribution statistics (mean, median, mode, variance, min,
        max).
    figsize : tuple[float, float]
        Size (width, height) in inches passed to
        :func:`matplotlib.pyplot.subplots`.
    show : bool
        Call :func:`matplotlib.pyplot.show` when ``True``.
    save : str | Path | None
        Save the figure to the provided path when given.
    ax : Axes | None
        Matplotlib Axes to plot onto. If ``None``, a new figure and axes are
        created.

    Returns
    -------
    Axes
        The Matplotlib Axes holding the histogram.

    Raises
    ------
    ValueError
        Invalid ``axis``, ``second_category="index"``, non-positive
        ``bin_width``, ``bin_range`` with ``lower >= upper``, or no entries
        to count.
    KeyError
        If a category column is missing from the selected frame.
    TypeError
        If ``bin_range`` is not a pair of finite numbers.
    """
    check_proteodata(adata)  # Ensures that the 'index' has unique values if used

    if axis not in (0, 1):
        raise ValueError("axis must be either 0 (.obs) or 1 (.var).")

    frame = adata.obs if axis == 0 else adata.var
    frame_label = ".obs" if axis == 0 else ".var"

    if second_category == "index":
        raise ValueError(
            "`second_category='index'` is not supported; pass 'index' via "
            "`first_category` instead."
        )
    if second_category not in frame:
        raise KeyError(
            f"Column '{second_category}' not found in adata{frame_label}."
        )
    if first_category != "index" and first_category not in frame:
        raise KeyError(
            f"Column '{first_category}' not found in adata{frame_label}."
        )

    if bin_width is not None:
        if bin_width <= 0:
            raise ValueError("bin_width must be a positive number.")

    if bin_range is not None:
        if (
            not isinstance(bin_range, (tuple, list))
            or len(bin_range) != 2
            or not all(np.isfinite(bin_range))
        ):
            raise TypeError(
                "bin_range must be a tuple of two finite numbers (lower, upper)."
            )
        lower, upper = bin_range
        if lower >= upper:
            raise ValueError("bin_range lower bound must be less than upper bound.")
        bin_range = (lower, upper)

    temp_col = (
        "__proteopy_axis_index__" if first_category == "index" else first_category
    )
    data = frame[[second_category]].copy()
    if first_category == "index":
        data[temp_col] = adata.obs_names if axis == 0 else adata.var_names
    else:
        data[temp_col] = frame[first_category]

    # Count each (second, first) pair once.
    data = data.drop_duplicates(subset=[second_category, temp_col])
    # observed=True: unused levels of a categorical column (e.g. left over
    # after subsetting) would otherwise show up as spurious zero counts.
    counts = data.groupby(second_category, observed=True).size()

    if counts.empty:
        raise ValueError(
            "No entries available to compute counts for the requested categories."
        )

    if bin_width is None:
        edges = np.histogram_bin_edges(counts.values, bins="auto")
        auto_width = edges[1] - edges[0]
        # Counts are integers, so bins narrower than 1 only add empty gaps.
        bin_width = max(auto_width, 1.0)

    if print_stats:
        stats_df = pd.DataFrame(
            {
                "mean": [counts.mean()],
                "median": [counts.median()],
                "mode": [counts.mode().iloc[0]],
                "variance": [counts.var()],
                "min": [counts.min()],
                "max": [counts.max()],
            }
        )
        print(stats_df.to_string(index=False))

    if ax is None:
        fig, _ax = plt.subplots(figsize=figsize)
    else:
        _ax = ax
        fig = _ax.get_figure()

    if first_category == "index":
        entry_label = "observations" if axis == 0 else "variables"
    else:
        entry_label = first_category

    sns.histplot(
        counts,
        binwidth=bin_width,
        binrange=bin_range,
        ax=_ax,
    )
    _ax.set_xlabel(f"Number of {entry_label} per {second_category}")
    _ax.set_ylabel(f"# {second_category}")
    fig.tight_layout()

    if save is not None:
        fig.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()

    return _ax
# Generic header of ``n_cat1_per_cat2_hist``; kept as a module-level name
# for backward compatibility.
docstr_header = (
    "Plot the distribution of the number of first-category entries per "
    "second category."
)

# Distribution of distinct ``peptide_id`` values per ``protein_id`` in .var.
n_peptides_per_protein = partial_with_docsig(
    n_cat1_per_cat2_hist,
    first_category="peptide_id",
    second_category="protein_id",
    axis=1,
    docstr_header=(
        "Plot the distribution of the number of peptides per protein."
    ),
)

# Distribution of distinct ``proteoform_id`` values per ``protein_id`` in .var.
n_proteoforms_per_protein = partial_with_docsig(
    n_cat1_per_cat2_hist,
    first_category="proteoform_id",
    second_category="protein_id",
    axis=1,
    docstr_header=(
        "Plot the distribution of the number of proteoforms per protein."
    ),
)
def cv_by_group(
    adata: ad.AnnData,
    group_by: str,
    layer: str | None = None,
    zero_to_na: bool = False,
    min_samples: int | None = None,
    force: bool = False,
    order: list | None = None,
    color_scheme=None,
    alpha: float = 0.8,
    hline: float | None = None,
    show_points: bool = False,
    point_alpha: float = 0.7,
    point_size: float = 1,
    xlabel_rotation: int | float = 0,
    figsize: tuple[float, float] = (6, 4),
    show: bool = True,
    ax: bool = False,
    save: str | None = None,
    print_stats: bool = False,
) -> Axes | None:
    """
    Compute per-group coefficients of variation and plot their distributions.

    CV values are read from ``adata.varm`` under the key
    ``"cv_by_<group_by>_<layer>"`` (both parts passed through
    ``sanitize_string``; ``layer`` defaults to ``"X"``) when that slot exists
    and ``force`` is ``False``. Otherwise they are computed with
    :func:`proteopy.pp.stats.calculate_cv` into a randomly named temporary
    ``adata.varm`` slot, which is deleted again after the values have been
    extracted (also when an error occurs).

    Parameters
    ----------
    adata : AnnData
        AnnData object that contains proteomics quantifications.
    group_by : str
        Column in ``adata.obs`` used to define observation groups for CV
        calculation.
    layer : str | None, optional
        AnnData layer to read intensities from. Defaults to ``adata.X``.
    zero_to_na : bool, optional
        Replace zero values with NaN before computing CVs. Default is
        ``False``.
    min_samples : int | None, optional
        Minimum number of observations per variable required to compute a
        CV. Variables with fewer non-NaN entries receive NaN. When ``None``
        (default), ``3`` is used. Must not be given when precomputed CV data
        from ``adata.varm`` is used.
    force : bool, optional
        Force recomputation of CV values even if precomputed data exists in
        ``adata.varm``. The recomputed values are stored in a temporary slot
        that is deleted after extracting the data. Default is ``False``.
    order : list | None, optional
        Explicit order of group labels (without the ``cv_`` prefix) along
        the x-axis. When ``None`` the observed group order is used.
    color_scheme : sequence, dict | None, optional
        Color assignments for groups. When None, uses the Matplotlib default
        color cycle.
    alpha : float, optional
        Transparency for the violin bodies. Default is ``0.8``.
    hline : float | None, optional
        If set, draw a horizontal dashed line at this CV value.
    show_points : bool, optional
        Overlay individual variable CVs as a strip plot. Default is
        ``False``.
    point_alpha : float, optional
        Opacity for individual points when ``show_points`` is ``True``.
    point_size : float, optional
        Size of the individual CV points. Default is ``1``.
    xlabel_rotation : float, optional
        Rotation angle (degrees) for the x-axis group labels.
    figsize : tuple of float, optional
        Matplotlib figure size in inches. Default is ``(6, 4)``.
    show : bool, optional
        Call ``plt.show()`` when ``True``. Default is ``True``.
    ax : bool, optional
        Return the Matplotlib Axes if ``True``.
    save : str | None, optional
        Path to save the figure. When ``None`` the figure is not saved.
    print_stats : bool, optional
        Print global and per-group CV summary statistics (and, when
        ``hline`` is set, counts/percentages of CVs below ``hline``).

    Returns
    -------
    Axes or None
        The violin-plot axes when ``ax`` is ``True``; otherwise ``None``.

    Raises
    ------
    KeyError
        If ``group_by`` is not a column of ``adata.obs``.
    ValueError
        If ``adata`` has no observations, ``group_by`` holds no non-missing
        labels, ``min_samples`` is combined with precomputed CV data, or
        ``order`` names groups without CV data.
    RuntimeError
        If the CV table is missing from ``adata.varm`` after computation.
    """
    check_proteodata(adata)

    if group_by not in adata.obs.columns:
        raise KeyError(f"Column '{group_by}' not found in adata.obs.")

    if adata.n_obs == 0:
        raise ValueError(
            "AnnData object contains no observations; cannot compute CVs."
        )

    groups = adata.obs[group_by]
    if groups.dropna().empty:
        raise ValueError(
            f"Column '{group_by}' does not contain any non-missing group labels."
        )

    if isinstance(groups.dtype, pd.CategoricalDtype):
        observed_groups = groups.cat.remove_unused_categories().cat.categories
        unique_groups = [str(cat) for cat in observed_groups]
    else:
        # Drop missing labels first; otherwise astype(str) would turn NaN into
        # a spurious "nan" group (the categorical branch never includes NaN).
        unique_groups = pd.Index(groups.dropna().astype(str)).unique().tolist()

    if not unique_groups:
        raise ValueError(
            f"Column '{group_by}' does not contain any finite groups."
        )

    # Use existing CV data if available; otherwise compute temporarily
    layer_suffix = sanitize_string(layer) if layer is not None else "X"
    varm_key = f"cv_by_{sanitize_string(group_by)}_{layer_suffix}"
    use_precomputed = varm_key in adata.varm and not force

    temp_key_name = None
    try:
        if use_precomputed:
            # `is not None` so that an explicit 0 is also rejected here.
            if min_samples is not None:
                raise ValueError(
                    f"Cannot use `min_samples={min_samples}` with precomputed CV "
                    f"data in adata.varm['{varm_key}']. Either:\n"
                    f" - Use `force=True` to recompute CV values with the new "
                    f"`min_samples` setting, or\n"
                    f" - Remove the precomputed data with "
                    f"`del adata.varm['{varm_key}']` before calling this function."
                )
            print(f"Using existing CV data from adata.varm['{varm_key}'].")
            key_to_use = varm_key
        else:
            # Random key prevents overwriting existing varm slots
            temp_key_name = f"_temp_cv_{uuid.uuid4().hex[:8]}"
            default_min_samples = 3
            if min_samples is None:
                min_samples = default_min_samples
            calculate_cv(
                adata,
                group_by=group_by,
                layer=layer,
                zero_to_na=zero_to_na,
                min_samples=min_samples,
                key_added=temp_key_name,
                inplace=True,
            )
            key_to_use = temp_key_name

        if key_to_use not in adata.varm:
            raise RuntimeError(
                f"Failed to compute CV data: adata.varm['{key_to_use}'] not found."
            )

        check_proteodata(adata)
        cv_df = adata.varm[key_to_use].copy()
    finally:
        # Always remove the temporary slot, even if anything above raised.
        if temp_key_name is not None and temp_key_name in adata.varm:
            del adata.varm[temp_key_name]

    df_melted = cv_df.melt(var_name="Group", value_name="CV", ignore_index=False)
    df_melted = df_melted.reset_index(drop=True)

    if order is None:
        order = unique_groups
    else:
        available_groups = set(df_melted["Group"].unique())
        missing = [grp for grp in order if grp not in available_groups]
        if missing:
            raise ValueError(
                "Requested ordering includes groups with no CV data: "
                f"{', '.join(map(str, missing))}."
            )

    resolved_colors = _resolve_color_scheme(color_scheme, order)
    if resolved_colors is None:
        palette = None
    else:
        palette = dict(zip(order, resolved_colors))

    if print_stats:
        cv_values = df_melted["CV"].dropna()
        global_summary = pd.DataFrame({
            "Count": [cv_values.count()],
            "Min": [round(cv_values.min(), 4)],
            "Max": [round(cv_values.max(), 4)],
            "Median": [round(cv_values.median(), 4)],
            "Mean": [round(cv_values.mean(), 4)],
            "Std": [round(cv_values.std(), 4)],
        })
        print("Global CV Summary:")
        print(global_summary.to_string(index=False))
        print()

        per_group = (
            df_melted.groupby("Group")["CV"]
            .agg(
                Count="count",
                Min="min",
                Max="max",
                Median="median",
                Mean="mean",
                Std="std",
            )
            .round(4)
            .reindex(order)
        )
        print("Per-Group CV Summary:")
        print(per_group.to_string())
        print()

        if hline is not None:
            below_count = (cv_values < hline).sum()
            total_count = cv_values.count()
            pct = (
                round(below_count / total_count * 100, 4)
                if total_count > 0
                else 0.0
            )
            global_thresh = pd.DataFrame({
                "Count below": [int(below_count)],
                "Percentage below": [pct],
            })
            print(
                f"Global Threshold Summary "
                f"(hline={hline}):"
            )
            print(global_thresh.to_string(index=False))
            print()

            def _thresh_stats(group_cv):
                # Count and percentage of non-NaN CVs strictly below hline.
                n_below = (group_cv < hline).sum()
                n_total = group_cv.count()
                pct_below = (
                    round(n_below / n_total * 100, 4)
                    if n_total > 0
                    else 0.0
                )
                return pd.Series({
                    "Count below": int(n_below),
                    "Percentage below": pct_below,
                })

            per_group_thresh = (
                df_melted.groupby("Group")["CV"]
                .apply(_thresh_stats)
                .unstack()
                .reindex(order)
            )
            print(
                f"Per-Group Threshold Summary "
                f"(hline={hline}):"
            )
            print(per_group_thresh.to_string())
            print()

    fig, ax_plot = plt.subplots(figsize=figsize, dpi=150)
    sns.violinplot(
        data=df_melted,
        x="Group",
        y="CV",
        hue="Group",
        order=order,
        palette=palette,
        cut=0,
        inner="box",
        alpha=alpha,
        legend=False,
        ax=ax_plot,
    )

    # Optionally overlay points
    if show_points:
        sns.stripplot(
            data=df_melted,
            x="Group",
            y="CV",
            order=order,
            color="black",
            alpha=point_alpha,
            size=point_size,
            jitter=0.2,
            dodge=False,
            ax=ax_plot,
        )

    # Optional horizontal dashed line
    if hline is not None:
        ax_plot.axhline(
            y=hline,
            color="black",
            linestyle="--",
            linewidth=1,
            alpha=0.8,
        )
        # add annotation for clarity
        ax_plot.text(
            x=-0.4,
            y=hline,
            s=f"{hline:.2f}",
            color="black",
            va="bottom",
            ha="left",
            fontsize=8,
        )

    ax_plot.set_xlabel("")
    ax_plot.set_ylabel("Coefficient of Variation (CV)")
    for label in ax_plot.get_xticklabels():
        label.set_rotation(xlabel_rotation)
    ax_plot.set_title("Distribution of CV across groups")
    sns.despine(ax=ax_plot)
    plt.tight_layout()

    check_proteodata(adata)

    if save:
        fig.savefig(save, dpi=300, bbox_inches="tight")
        print(f"Figure saved to: {save}")
    if show:
        plt.show()
    if ax:
        return ax_plot
    return None
def sample_correlation_matrix(
    adata: ad.AnnData,
    method: str = "pearson",
    zero_to_na: bool = False,
    layer: str | None = None,
    fill_na: float | None = None,
    margin_color: str | None = None,
    color_scheme=None,
    cmap: str = "coolwarm",
    linkage_method: str = "average",
    xticklabels: bool = False,
    yticklabels: bool = False,
    figsize: tuple[float, float] = (9.0, 7.0),
    show: bool = True,
    ax: bool = False,
    print_stats: bool = False,
    save: str | Path | None = None,
) -> Axes | None:
    """
    Plot a clustered correlation heatmap across samples (obs).

    Samples are clustered hierarchically on the distance ``1 - r`` (clipped
    to ``[0, 2]``), and the color scale is centered on the mean off-diagonal
    correlation.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` with proteomics annotations.
    method : str
        Correlation estimator passed to :meth:`pandas.DataFrame.corr`.
    zero_to_na : bool
        Replace zeros with missing values before computing correlations.
    layer : str | None
        Optional ``adata.layers`` key to draw quantification values from.
        When ``None`` the primary matrix ``adata.X`` is used.
    fill_na : float | None
        Constant used to replace remaining ``NaN`` values prior to
        correlation. When ``None`` (default), a :class:`ValueError` is raised
        if missing values are detected (suggesting ``fill_na=0``).
    margin_color : str | None
        Optional column in ``adata.obs`` used to color the row/column margins
        of the heatmap. Missing labels are drawn in light gray ("NA").
    color_scheme : Any
        Color palette specification understood by
        :func:`proteopy.utils.matplotlib._resolve_color_scheme`.
    cmap : str
        Continuous colormap for the heatmap body.
    linkage_method : str
        Linkage criterion handed to :func:`scipy.cluster.hierarchy.linkage`.
    xticklabels, yticklabels : bool
        Whether to show x- and y-axis tick labels.
    figsize : tuple[float, float]
        Matplotlib figure size in inches.
    show : bool
        Display the figure with :func:`matplotlib.pyplot.show`.
    ax : bool
        Return the heatmap :class:`matplotlib.axes.Axes` when ``True``.
    print_stats : bool
        Print correlation summary statistics before drawing the plot.
        Includes overall off-diagonal statistics, per-sample mean
        correlation, and per-group correlations when ``margin_color`` is
        provided.
    save : str | Path | None
        File path for saving the Seaborn cluster map. When ``None`` nothing
        is written.

    Returns
    -------
    Axes or None
        Heatmap axes when ``ax`` is ``True``; otherwise ``None``.

    Raises
    ------
    KeyError
        If ``layer`` is not in ``adata.layers`` or ``margin_color`` is not a
        column of ``adata.obs``.
    ValueError
        If fewer than two observations are present, the selected matrix is
        missing or mis-shaped, it still contains missing values after
        optional zero replacement and ``fill_na`` is ``None``, the
        correlation matrix contains non-finite values, or a ``margin_color``
        category has no color.
    """
    check_proteodata(adata)

    # Clustering needs at least one pairwise distance.
    if adata.n_obs < 2:
        raise ValueError(
            "At least two observations are required to compute and cluster "
            "a sample correlation matrix."
        )

    # ---- values from adata.X or a specified layer (obs × var)
    expected_shape = (adata.n_obs, adata.n_vars)
    if layer is None:
        matrix = adata.X
    else:
        if layer not in adata.layers:
            raise KeyError(f"Layer '{layer}' not found in adata.layers.")
        matrix = adata.layers[layer]

    if matrix is None:
        raise ValueError("Selected matrix is empty; cannot compute correlations.")

    if matrix.shape != expected_shape:
        raise ValueError(
            "Selected matrix shape "
            f"{matrix.shape} does not match adata dimensions {expected_shape}."
        )

    if isinstance(matrix, pd.DataFrame):
        vals = matrix.reindex(
            index=adata.obs_names, columns=adata.var_names
        ).copy()
    else:
        # correlation requires dense values; convert temporarily
        dense_matrix = (
            matrix.toarray() if sparse.issparse(matrix) else np.asarray(matrix)
        )
        vals = pd.DataFrame(
            dense_matrix,
            index=adata.obs_names,
            columns=adata.var_names,
        )

    if zero_to_na:
        vals = vals.replace(0, np.nan)

    if fill_na is not None:
        vals = vals.fillna(fill_na)

    if vals.isna().to_numpy().any():
        raise ValueError(
            "Input matrix contains missing values; provide `fill_na` (e.g., "
            "`fill_na=0`) to replace them before computing correlations."
        )

    # ---- obs×obs correlation
    corr_df = vals.T.corr(method=method)
    corr_df.index = adata.obs_names
    corr_df.columns = adata.obs_names

    # ---- off-diagonal mean used as the color-scale center
    A = corr_df.values.astype(float, copy=False)
    n = A.shape[0]
    offdiag_mask = ~np.eye(n, dtype=bool)
    offdiag = A[offdiag_mask]
    center_val = np.nanmean(offdiag)

    # ---- optional row/col colors from obs[margin_color]
    row_colors = None
    legend_handles = None
    if margin_color is not None:
        if margin_color not in adata.obs.columns:
            raise KeyError(f"Column '{margin_color}' not found in adata.obs.")

        groups = adata.obs.loc[corr_df.index, margin_color]
        na_mask = groups.isna()
        cats = pd.Categorical(groups.dropna()).categories
        resolved_colors = _resolve_color_scheme(color_scheme, cats)
        if resolved_colors is None:
            resolved_colors = (
                sns.color_palette(n_colors=len(cats)) if len(cats) > 0 else []
            )
        palette = {str(cat): color for cat, color in zip(cats, resolved_colors)}

        row_color_series = groups.astype("string").map(palette)
        missing_mask = row_color_series.isna() & ~na_mask
        if missing_mask.any():
            missing_cats = sorted(groups[missing_mask].astype(str).unique())
            raise ValueError(
                "No color provided for categories: "
                f"{', '.join(missing_cats)} in '{margin_color}'."
            )

        legend_handles = [
            Patch(facecolor=palette[str(cat)], edgecolor="none", label=str(cat))
            for cat in cats
        ]

        if na_mask.any():
            na_color = mpl.colors.to_rgba("lightgray")
            # Rebuild element-wise: assigning an RGBA tuple through a boolean
            # mask makes pandas treat it as four separate values (raising a
            # length-mismatch error or scattering r/g/b/a across rows).
            row_color_series = pd.Series(
                [
                    na_color if is_na else color
                    for color, is_na in zip(row_color_series, na_mask)
                ],
                index=row_color_series.index,
                dtype=object,
            )
            legend_handles.append(
                Patch(facecolor=na_color, edgecolor="none", label="NA")
            )

        row_colors = row_color_series.to_numpy()

    # ---- hierarchical clustering on (1 - r)
    dist = 1 - corr_df.values
    np.fill_diagonal(dist, 0.0)
    if not np.isfinite(dist).all():
        raise ValueError(
            "Correlation matrix contains non-finite values (e.g. from samples "
            "with constant intensities); cannot cluster samples."
        )
    dist = np.clip(dist, 0, 2)  # numerical guard
    Z = linkage(squareform(dist), method=linkage_method)

    # ---- optional statistics printout
    if print_stats:
        # 1) Overall off-diagonal summary
        summary = pd.DataFrame({
            "min": [np.nanmin(offdiag)],
            "max": [np.nanmax(offdiag)],
            "mean": [np.nanmean(offdiag)],
            "median": [np.nanmedian(offdiag)],
            "std": [np.nanstd(offdiag)],
        })
        print(
            f"Sample correlation summary "
            f"(off-diagonal, {method}):"
        )
        print(summary.to_string(index=False))
        print()

        # 2) Per-sample mean correlation (dendrogram order)
        per_sample_mean = np.nanmean(
            np.where(offdiag_mask, A, np.nan), axis=1
        )
        heatmap_order = leaves_list(Z)
        per_sample_df = pd.DataFrame({
            "sample_id": corr_df.index[heatmap_order],
            "mean_corr": per_sample_mean[heatmap_order],
        })
        print("Per-sample mean correlation:")
        print(per_sample_df.to_string(index=False))
        print()

        # 3) Per-group correlation (column already validated above)
        if margin_color is not None:
            groups_ps = adata.obs.loc[corr_df.index, margin_color]
            unique_groups = groups_ps.dropna().unique()
            group_rows = []
            for grp in sorted(unique_groups):
                grp_idx = groups_ps[groups_ps == grp].index
                other_idx = groups_ps[
                    (groups_ps != grp) & groups_ps.notna()
                ].index

                within = corr_df.loc[grp_idx, grp_idx]
                within_vals = within.values[
                    ~np.eye(len(grp_idx), dtype=bool)
                ]
                mean_within = (
                    np.nanmean(within_vals)
                    if len(within_vals) > 0
                    else np.nan
                )

                if len(other_idx) > 0:
                    between_vals = corr_df.loc[
                        grp_idx, other_idx
                    ].values.ravel()
                    mean_between = np.nanmean(between_vals)
                else:
                    mean_between = np.nan

                group_rows.append({
                    "group": grp,
                    "mean_within": mean_within,
                    "mean_between": mean_between,
                })

            group_df = pd.DataFrame(group_rows)
            print("Per-group mean correlation:")
            print(group_df.to_string(index=False))
            print()

    # ---- clustermap (center at off-diagonal mean)
    g = sns.clustermap(
        corr_df,
        row_linkage=Z,
        col_linkage=Z,
        row_colors=row_colors,
        col_colors=row_colors,
        cmap=cmap,
        center=center_val,
        figsize=figsize,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        cbar_kws={"label": f"{method.capitalize()}"},
    )

    # ---- add legend for margin_color colors
    if legend_handles is not None:
        g.ax_heatmap.legend(
            handles=legend_handles,
            title=margin_color,
            bbox_to_anchor=(1.05, 1),
            loc='upper left',
            borderaxespad=0.,
            frameon=False,
        )

    g.ax_heatmap.set_xlabel("Samples")
    g.ax_heatmap.set_ylabel("Samples")
    plt.tight_layout()

    # Save before showing, so the figure is written before any backend
    # display/teardown (matches the other plotting functions).
    if save:
        g.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    if ax:
        return g.ax_heatmap
    return None
def hclustv_profiles_heatmap(
    adata: ad.AnnData,
    selected_vars: list[str] | None = None,
    group_by: str | None = None,
    summary_method: str = "median",
    linkage_method: str = "average",
    distance_metric: str = "euclidean",
    layer: str | None = None,
    zero_to_na: bool = False,
    fill_na: float | int | None = None,
    skip_na: bool = True,
    cmap: str = "coolwarm",
    margin_color: bool = False,
    order_by: str | None = None,
    order: str | list | None = None,
    color_scheme: str | dict | Sequence | Colormap | None = None,
    row_cluster: bool = True,
    col_cluster: bool = True,
    cbar_pos: tuple[float, float, float, float] | None = (
        0.02, 0.8, 0.05, 0.18
    ),
    tree_kws: dict | None = None,
    xticklabels: bool = True,
    yticklabels: bool = False,
    figsize: tuple[float, float] = (10.0, 8.0),
    title: str | None = None,
    show: bool = True,
    ax: bool = False,
    save: str | Path | None = None,
) -> Axes | None:
    """
    Plot a clustered heatmap of variable profiles across samples or groups.

    Computes z-scores for each variable across samples (or group summaries),
    then applies hierarchical clustering to visualize variable expression
    patterns. Z-scores that are undefined (missing values, zero variance)
    are set to 0 before plotting.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` with proteomics annotations.
    selected_vars : list[str] | None
        Explicit list of variables to include. When ``None``, all variables
        are used.
    group_by : str | None
        Column in ``adata.obs`` used to group observations. When provided,
        computes a summary statistic for each group rather than showing
        individual samples.
    summary_method : str
        Method for computing group summaries when ``group_by`` is specified.
        One of ``"median"`` or ``"mean"`` (alias ``"average"``).
    linkage_method : str
        Linkage criterion passed to :func:`scipy.cluster.hierarchy.linkage`.
    distance_metric : str
        Distance metric for clustering. One of ``"euclidean"``,
        ``"manhattan"``, or ``"cosine"``.
    layer : str | None
        Optional ``adata.layers`` key to draw quantification values from.
        When ``None`` the primary matrix ``adata.X`` is used.
    zero_to_na : bool
        Replace zeros with ``NaN`` before computing profiles.
    fill_na : float | int | None
        Replace ``NaN`` values with the specified constant.
    skip_na : bool
        Skip ``NaN`` values when computing group summaries and z-scores.
    cmap : str
        Colormap for the heatmap body.
    margin_color : bool
        Add a color bar between the column dendrogram and the heatmap. When
        ``True``, colors by sample (if ``group_by`` is ``None``) or by group
        (if ``group_by`` is set).
    order_by : str | None
        Column in ``adata.obs`` used to order samples (columns). When set,
        automatically disables column clustering and orders columns by the
        values of this column. Also displays a margin color bar colored by
        this column. Cannot be used with ``group_by``.
    order : str | list | None
        The order by which to present samples, groups, or categories. A
        single string is treated as a one-element list.
        If ``order_by`` is ``None`` and ``order`` is ``None``, the existing
        order is used.
        If ``order_by`` is ``None`` and ``order`` is not ``None``, ``order``
        specifies the column order (samples or groups).
        If ``order_by`` is not ``None`` and ``order`` is ``None``, the unique
        values in ``order_by`` are used (categorical order if categorical,
        sorted order otherwise).
        If ``order_by`` is not ``None`` and ``order`` is not ``None``,
        ``order`` defines the order of the unique ``order_by`` values.
        Values not in ``order`` are excluded.
        Samples sharing the same ``order_by`` value keep their relative
        order.
    color_scheme : str | dict | Sequence | Colormap | None
        Palette specification for the margin color bar, forwarded to
        :func:`proteopy.utils.matplotlib._resolve_color_scheme`. Ignored when
        neither ``margin_color`` nor ``order_by`` is set.
    row_cluster : bool
        Perform hierarchical clustering on variables (rows).
    col_cluster : bool
        Perform hierarchical clustering on samples/groups (columns). Turned
        off automatically when ``order_by`` or ``order`` is given.
    cbar_pos : tuple of (left, bottom, width, height), optional
        Position of the colorbar axes in the figure. Setting to ``None``
        will disable the colorbar.
    tree_kws : dict, optional
        Keyword arguments passed to
        :class:`matplotlib.collections.LineCollection` for the dendrogram
        lines (e.g. ``colors``, ``linewidths``).
    xticklabels : bool
        Show x-axis tick labels (sample/group names).
    yticklabels : bool
        Show y-axis tick labels (variable names).
    figsize : tuple[float, float]
        Matplotlib figure size in inches.
    title : str | None
        Title for the plot.
    show : bool
        Display the figure with :func:`matplotlib.pyplot.show`.
    ax : bool
        Return the heatmap :class:`matplotlib.axes.Axes` when ``True``.
    save : str | Path | None
        File path for saving the figure.

    Returns
    -------
    Axes or None
        Heatmap axes when ``ax`` is ``True``; otherwise ``None``.

    Raises
    ------
    ValueError
        For an invalid ``summary_method``/``distance_metric``, ``order_by``
        combined with ``group_by``, an empty matrix, or when no variables
        remain after removing all-NaN rows.
    KeyError
        For unknown ``group_by``/``order_by`` columns, layers, variables, or
        ``order`` entries.
    """
    check_proteodata(adata)

    # Validate summary_method
    summary_method = summary_method.lower()
    if summary_method == "average":
        summary_method = "mean"
    if summary_method not in ("median", "mean"):
        raise ValueError(
            f"summary_method must be 'median' or 'mean', got '{summary_method}'."
        )

    # Validate distance_metric and map it to the scipy pdist name
    distance_metric = distance_metric.lower()
    metric_map = {
        "euclidean": "euclidean",
        "manhattan": "cityblock",
        "cosine": "cosine",
    }
    if distance_metric not in metric_map:
        raise ValueError(
            f"distance_metric must be 'euclidean', 'manhattan', or 'cosine', "
            f"got '{distance_metric}'."
        )
    scipy_metric = metric_map[distance_metric]

    # Validate group_by up front: `order` validation below indexes it.
    if group_by is not None and group_by not in adata.obs.columns:
        raise KeyError(f"Column '{group_by}' not found in adata.obs.")

    # Validate order_by
    if order_by is not None:
        if group_by is not None:
            raise ValueError(
                "order_by cannot be used with group_by. When using group_by, "
                "columns represent groups, not individual samples."
            )
        if order_by not in adata.obs.columns:
            raise KeyError(f"Column '{order_by}' not found in adata.obs.")
        # order_by and col_cluster are mutually exclusive; disable clustering
        if col_cluster:
            print(
                "`order_by` parameter is incompatible with `col_cluster=True`. "
                "`col_cluster` has been overridden."
            )
            col_cluster = False

    # Validate order parameter
    if order is not None:
        if col_cluster:
            print(
                "`order` parameter is incompatible with `col_cluster=True`. "
                "`col_cluster` has been overridden."
            )
            col_cluster = False
        # A bare string is a single label; list() would split it into chars.
        order = [order] if isinstance(order, str) else list(order)
        if order_by is None and group_by is None:
            # order specifies sample names
            available_samples = set(adata.obs_names)
            invalid_samples = [s for s in order if s not in available_samples]
            if invalid_samples:
                raise KeyError(
                    f"Samples not found in adata.obs_names: {invalid_samples}"
                )
        elif group_by is not None:
            # order specifies group names; validate against group_by column
            available_groups = set(adata.obs[group_by].dropna().unique())
            invalid_groups = [g for g in order if g not in available_groups]
            if invalid_groups:
                raise KeyError(
                    f"Groups not found in adata.obs['{group_by}']: {invalid_groups}"
                )
        # Validation for the order_by case happens once the data is built

    # Extract matrix
    if layer is None:
        matrix = adata.X
    else:
        if layer not in adata.layers:
            raise KeyError(f"Layer '{layer}' not found in adata.layers.")
        matrix = adata.layers[layer]

    if matrix is None:
        raise ValueError("Selected matrix is empty.")

    # Densify if sparse
    matrix = matrix.toarray() if sparse.issparse(matrix) else np.asarray(matrix)

    # Create DataFrame (obs x var)
    df = pd.DataFrame(
        matrix,
        index=adata.obs_names,
        columns=adata.var_names,
    )

    # Filter variables if specified
    if selected_vars is not None:
        missing_vars = [v for v in selected_vars if v not in df.columns]
        if missing_vars:
            raise KeyError(
                f"Variables not found in adata.var_names: {missing_vars}"
            )
        df = df[selected_vars]

    if zero_to_na:
        df = df.replace(0, np.nan)

    if fill_na is not None:
        df = df.fillna(fill_na)

    if group_by is not None:
        # `assign` returns a new frame instead of writing into a possible
        # column-subset copy of df (avoids SettingWithCopyWarning).
        # `.values` keeps a categorical dtype, so group order follows the
        # categories and observed=True drops unused ones.
        grouped = df.assign(
            __group__=adata.obs[group_by].values
        ).groupby("__group__", observed=True)

        def _summarize(frame):
            # Per-group median or mean of every variable.
            return getattr(frame, summary_method)(skipna=skip_na)

        # include_groups=False excludes __group__ from the function input
        summary_df = grouped.apply(_summarize, include_groups=False)
        profile_df = summary_df.T  # var x group
    else:
        profile_df = df.T  # var x obs

    # Drop variables with all NaN
    profile_df = profile_df.dropna(how="all")
    if profile_df.empty:
        raise ValueError("No variables remain after removing all-NaN rows.")

    # Compute z-scores per variable (row)
    row_means = profile_df.mean(axis=1, skipna=skip_na)
    row_stds = profile_df.std(axis=1, skipna=skip_na, ddof=0)
    row_stds = row_stds.replace(0, np.nan)  # avoid division by zero
    z_df = profile_df.sub(row_means, axis=0).div(row_stds, axis=0)

    # Fill NaN with 0 for clustering
    z_df_filled = z_df.fillna(0)

    # Order columns based on order_by and/or order
    if order_by is not None:
        order_col_values = adata.obs.loc[z_df_filled.columns, order_by]

        if order is not None:
            # Validate that order values exist in the order_by column
            available_values = set(order_col_values.unique())
            invalid_values = [v for v in order if v not in available_values]
            if invalid_values:
                raise KeyError(
                    f"Values not found in adata.obs['{order_by}']: {invalid_values}"
                )
            # Keep samples whose order_by value is listed, then sort by the
            # position of that value in `order`.
            keep = order_col_values.isin(order)
            filtered_cols = z_df_filled.columns[keep]
            ranked = pd.Categorical(
                order_col_values.loc[filtered_cols],
                categories=order,
                ordered=True,
            )
            sorted_idx = (
                pd.Series(ranked, index=filtered_cols)
                .sort_values(kind="stable")
                .index
            )
        elif isinstance(order_col_values.dtype, pd.CategoricalDtype):
            # Use the column's categorical order
            ranked = pd.Categorical(
                order_col_values,
                categories=list(order_col_values.cat.categories),
                ordered=True,
            )
            sorted_idx = (
                pd.Series(ranked, index=z_df_filled.columns)
                .sort_values(kind="stable")
                .index
            )
        else:
            sorted_idx = order_col_values.sort_values(kind="stable").index

        z_df_filled = z_df_filled[sorted_idx]
    elif order is not None:
        # order specifies sample or group names directly
        z_df_filled = z_df_filled[
            [c for c in order if c in z_df_filled.columns]
        ]

    # Build column colors for margin annotation
    col_colors = None
    col_names = z_df_filled.columns
    if order_by is not None:
        categories = adata.obs.loc[col_names, order_by].values
    elif margin_color:
        categories = col_names  # one color per sample or group
    else:
        categories = None

    if categories is not None:
        unique_cats = pd.Series(categories).unique()
        resolved_colors = _resolve_color_scheme(color_scheme, unique_cats)
        if resolved_colors is None:
            resolved_colors = (
                sns.color_palette("husl", n_colors=len(unique_cats))
                if len(unique_cats) > 0
                else []
            )
        color_map = dict(zip(unique_cats, resolved_colors))
        col_colors = pd.Series(
            [color_map[c] for c in categories],
            index=col_names,
        )

    # Create clustermap
    clustermap_kws = dict(
        method=linkage_method,
        metric=scipy_metric,
        row_cluster=row_cluster,
        col_cluster=col_cluster,
        cmap=cmap,
        center=0,
        figsize=figsize,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        col_colors=col_colors,
        tree_kws=tree_kws,
        cbar_pos=cbar_pos,
    )
    if cbar_pos is not None:
        clustermap_kws["cbar_kws"] = {"label": "Z-score"}

    g = sns.clustermap(z_df_filled, **clustermap_kws)
    g.ax_heatmap.set_xlabel("")

    # Remove y-axis ticks from the margin color bar if present
    if g.ax_col_colors is not None:
        g.ax_col_colors.set_yticks([])

    if title is not None:
        g.figure.suptitle(title, y=1.02)

    plt.tight_layout()

    if save:
        g.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    if ax:
        return g.ax_heatmap
    return None