harpy.utils.kronos_embedding


harpy.utils.kronos_embedding(array, embedding_dimension, matched_channels, do_instance_embedding=True, checkpoint_path='hf_hub:MahmoodLab/kronos', hf_auth_token=None, cache_dir=None, model_type='vits16', token_overlap=False, max_value=1, channel_id_pretrained_name='marker_id_pretrained', channel_id_data_specific_name='marker_id', channel_mean_name='marker_mean', channel_std_name='marker_std')

Compute KRONOS embeddings for multi-channel instance windows using a pre-trained vision transformer.

The input array must have shape (i, c, z, y, x). Using matched_channels, channels are matched to the pre-trained model's expected schema, and unmatched channels are filtered out. Intensities are first scaled to [0, 1] by dividing by max_value and then standardized per channel using the means and standard deviations provided in matched_channels. The normalized tensors are embedded with KRONOS (e.g., a ViT-S/16 backbone), producing for each instance either a single feature vector of length d=384 if do_instance_embedding is True, or a vector of length 384*c_matched otherwise.
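The preprocessing steps described above (squeeze z, divide by max_value, standardize each channel) can be sketched in NumPy as follows. This is a minimal illustration, not harpy's implementation: preprocess_instances is a hypothetical helper, and it assumes the means and standard deviations are already ordered like the channel axis of array.

```python
import numpy as np


def preprocess_instances(array, means, stds, max_value=1):
    """Hypothetical sketch of the described preprocessing (not harpy's code).

    array: (i, c, z, y, x) with z == 1.
    means, stds: one value per channel, ordered like axis 1 of `array`.
    """
    if array.shape[2] != 1:
        raise ValueError("z must equal 1")
    # Drop the singleton z axis: (i, c, 1, y, x) -> (i, c, y, x).
    x = np.squeeze(array, axis=2).astype(np.float32)
    # Scale intensities to [0, 1].
    x = x / max_value
    # Reshape the per-channel stats so they broadcast over (i, c, y, x).
    means = np.asarray(means, dtype=np.float32).reshape(1, -1, 1, 1)
    stds = np.asarray(stds, dtype=np.float32).reshape(1, -1, 1, 1)
    # Per-channel standardization.
    return (x - means) / stds


# Usage with random 8-bit-range data stored as uint16 (placeholder values).
rng = np.random.default_rng(0)
windows = rng.integers(0, 256, size=(4, 3, 1, 32, 32), dtype=np.uint16)
out = preprocess_instances(
    windows, means=[0.2, 0.1, 0.3], stds=[0.1, 0.05, 0.2], max_value=255
)
print(out.shape)  # (4, 3, 32, 32)
```

Reshaping the statistics to (1, c, 1, 1) lets NumPy broadcasting apply each channel's mean and standard deviation to every instance and pixel without an explicit loop.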

Parameters:
  • array (ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]) – Input data of shape (i, c, z, y, x). z must equal 1; it will be squeezed to (i, c, y, x) before embedding.

  • embedding_dimension (int) – Dimensionality of the embedding vectors produced by the model. Set to 384 if do_instance_embedding is True, else to 384*c_matched, where c_matched is the number of matched channels.

  • matched_channels (DataFrame) –

    A pandas DataFrame used to reconcile dataset-specific channels with the pre-trained KRONOS channel set. It must contain at least the following columns (names configurable via arguments below):

    • channel_id_pretrained_name (default: "marker_id_pretrained"): identifier used by the pre-trained checkpoint (ordering reference).

    • channel_id_data_specific_name (default: "marker_id"): channel IDs of array (the channels present in your data).

    • channel_mean_name (default: "marker_mean"): per-channel mean used for standardization.

    • channel_std_name (default: "marker_std"): per-channel standard deviation used for standardization.

    Only channels of array whose index is matched (via matched_channels) to the pre-trained schema are kept.

  • do_instance_embedding (bool (default: True)) – If True (default), KRONOS aggregates the computed embeddings across channels, producing an output of shape (i, 384). If False, the output has shape (i, c_matched*384), where c_matched is the number of channels that were successfully matched between your data and the pre-trained schema (see the description of matched_channels).

  • checkpoint_path (str | Path (default: 'hf_hub:MahmoodLab/kronos')) – Path or identifier for the pre-trained weights. By default, loads from the Hugging Face Hub ("hf_hub:MahmoodLab/kronos"). Local filesystem paths are also supported.

  • hf_auth_token (str | None (default: None)) – Hugging Face authentication token. Must be provided when loading the checkpoint from the Hugging Face Hub (the default checkpoint_path).

  • cache_dir (str | Path | None (default: None)) – Cache directory for the model weights when loading from the Hugging Face Hub.

  • model_type (str (default: 'vits16')) – Backbone type for the KRONOS encoder (e.g., "vits16"). Must be compatible with the provided checkpoint.

  • token_overlap (bool (default: False)) – If True, use overlapping tokens during feature extraction. The KRONOS tutorial workflow for unsupervised phenotyping uses False (the default).

  • max_value (int (default: 1)) – Maximum expected pixel value in array. The data are divided by this value to map intensities into [0, 1] before per-channel standardization. Choose based on image type (e.g., 255 for 8-bit, 65535 for 16-bit, or 1 for floating-point data already in [0, 1]).

  • channel_id_pretrained_name (str (default: 'marker_id_pretrained')) – Column name in matched_channels that encodes the pre-trained channel IDs.

  • channel_id_data_specific_name (str (default: 'marker_id')) – Column name in matched_channels that encodes the dataset-specific channel IDs.

  • channel_mean_name (str (default: 'marker_mean')) – Column name in matched_channels holding per-channel means used for standardization.

  • channel_std_name (str (default: 'marker_std')) – Column name in matched_channels holding per-channel standard deviations used for standardization.
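The parameters above can be tied together in a usage sketch, assuming harpy is installed and the default Hub checkpoint is accessible with your token. The IDs, means, and standard deviations in matched_channels are placeholders, not real KRONOS marker metadata. max_value_for, expected_embedding_dimension, and run_kronos are hypothetical helpers that illustrate the max_value and embedding_dimension rules; the harpy import sits inside run_kronos so the sketch can be defined without downloading any weights.

```python
import numpy as np
import pandas as pd

# Placeholder mapping using the default column names; real values must come
# from the pre-trained KRONOS marker metadata and your own channel list.
matched_channels = pd.DataFrame(
    {
        "marker_id_pretrained": [3, 17, 42],  # placeholder pre-trained IDs
        "marker_id": [0, 2, 5],  # placeholder dataset-specific channel IDs
        "marker_mean": [0.05, 0.12, 0.08],  # placeholder means (after /max_value)
        "marker_std": [0.03, 0.07, 0.04],  # placeholder standard deviations
    }
)


def max_value_for(array):
    """Hypothetical helper: choose max_value from the array dtype."""
    if np.issubdtype(array.dtype, np.integer):
        # e.g. 255 for uint8, 65535 for uint16
        return int(np.iinfo(array.dtype).max)
    # Floating-point data are assumed to already lie in [0, 1].
    return 1


def expected_embedding_dimension(n_matched, do_instance_embedding=True, d=384):
    """Hypothetical helper: 384, or 384 per matched channel when not aggregating."""
    return d if do_instance_embedding else d * n_matched


def run_kronos(array, hf_auth_token, do_instance_embedding=True):
    """Hypothetical wrapper around harpy.utils.kronos_embedding (not executed here)."""
    from harpy.utils import kronos_embedding

    # Assumes every row of matched_channels is one matched channel.
    n_matched = len(matched_channels)
    return kronos_embedding(
        array,
        embedding_dimension=expected_embedding_dimension(
            n_matched, do_instance_embedding
        ),
        matched_channels=matched_channels,
        do_instance_embedding=do_instance_embedding,
        hf_auth_token=hf_auth_token,
        max_value=max_value_for(array),
    )
```

Deriving embedding_dimension and max_value from the inputs keeps them consistent with do_instance_embedding and the array dtype, rather than hard-coding them per call.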

Return type:

ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]

Returns:

  • If do_instance_embedding=True: an array of shape (i, 384).

  • If do_instance_embedding=False: an array of shape (i, c_matched*384).

Here, c_matched is the number of channels that were successfully matched between your data and the pre-trained schema.
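The two output shapes can be illustrated with placeholder per-channel embeddings. This is a sketch under stated assumptions: random numbers stand in for KRONOS outputs, the concatenated layout is assumed to be channel-major, and mean pooling is only a stand-in for the cross-channel aggregation, which this docstring does not specify.

```python
import numpy as np

# Placeholder per-channel embeddings: (i instances, c_matched channels, 384 features).
i, c_matched, d = 5, 4, 384
per_channel = np.random.default_rng(0).normal(size=(i, c_matched, d))

# do_instance_embedding=False: channel embeddings side by side -> (i, c_matched*384).
# Channel-major ordering (all features of channel 0 first) is an assumption.
concatenated = per_channel.reshape(i, c_matched * d)

# do_instance_embedding=True: one vector per instance -> (i, 384).
# Mean pooling is a stand-in, not necessarily what KRONOS applies.
aggregated = per_channel.mean(axis=1)

print(concatenated.shape)  # (5, 1536)
print(aggregated.shape)  # (5, 384)
```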

Notes

Channels not present in matched_channels are dropped.
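Dropping unmatched channels can be sketched with NumPy fancy indexing. drop_unmatched_channels is a hypothetical helper; it assumes the dataset-specific channel IDs (e.g. the values of matched_channels["marker_id"]) are integer positions along the channel axis of array, and it keeps channels in the order given, without claiming how harpy orders them relative to the pre-trained IDs.

```python
import numpy as np


def drop_unmatched_channels(array, matched_ids):
    """Hypothetical helper: keep only channels whose index appears in matched_ids.

    Assumes dataset-specific channel IDs are integer positions along axis 1
    of an (i, c, z, y, x) array.
    """
    matched_ids = np.asarray(matched_ids, dtype=int)
    if matched_ids.size and (
        matched_ids.min() < 0 or matched_ids.max() >= array.shape[1]
    ):
        raise ValueError("matched channel index out of range")
    # Fancy indexing on axis 1 keeps the listed channels, in the order given.
    return array[:, matched_ids]


# Usage: 6 channels in the data, only 3 of them matched (placeholder IDs).
array = np.arange(2 * 6 * 1 * 2 * 2).reshape(2, 6, 1, 2, 2)
kept = drop_unmatched_channels(array, [0, 2, 5])
print(kept.shape)  # (2, 3, 1, 2, 2)
```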