# Source code for laion_fmri.embeddings

"""Access pretrained image embeddings for the LAION-fMRI stimuli.

Use after the embeddings have been downloaded via
:func:`laion_fmri.download.download_embeddings` (or
``laion-fmri download-embeddings``).

The embeddings live as one HDF5 file per model, sitting next to the
stimulus images:

.. code-block:: text

   stimuli/
     task-images_stimuli.h5
     task-images_metadata.csv
     task-images_desc-CLIP_embeddings.h5
     task-images_desc-DINOv2_embeddings.h5
     task-images_desc-PEcore_embeddings.h5
     task-images_desc-SigLIP2_embeddings.h5

Each file has three datasets of length 25,052: ``embedding`` (the
``(N, feature_dim)`` float16 matrix), ``image_ids`` (image filenames),
and ``valid`` (per-row validity flag). All four files share the same
``image_ids`` order.

You normally do not construct :class:`Embeddings` directly. Reach it
through the :class:`~laion_fmri.Stimuli` hub:

>>> import laion_fmri
>>> stim = laion_fmri.load_stimuli()
>>> stim.embeddings["CLIP"].shape          # (25052, 1024)
>>> stim.embeddings.get("CLIP", "shared_12rep_LAION_cluster_1003_i0.jpg")

For subject-aligned arrays, use the Subject namespace:

>>> sub = laion_fmri.load_subject("sub-01")
>>> features = sub.embeddings.all("CLIP")  # (n_trials, D)
"""

from __future__ import annotations

from pathlib import Path

import h5py
import numpy as np
import pandas as pd

from laion_fmri._paths import embeddings_h5_path, stimuli_metadata_path
from laion_fmri.config import get_data_dir


#: Models shipped with the LAION-fMRI release. Each label is the BIDS
#: ``desc-`` token used in the embedding filename, e.g. ``"CLIP"`` for
#: ``task-images_desc-CLIP_embeddings.h5``.
AVAILABLE_MODELS = ("CLIP", "DINOv2", "PEcore", "SigLIP2")


class Embeddings:
    """Lazy reader for one or more model embedding files.

    Opens each model's HDF5 on first access and keeps the handle open
    for the lifetime of the instance. Use as a context manager to
    explicitly release the handles::

        with Stimuli() as stim:
            v = stim.embeddings.get("CLIP", "img.jpg")

    Parameters
    ----------
    models : str or iterable[str]
        Model labels this handle covers (subset of
        :data:`AVAILABLE_MODELS`). A single string such as ``"CLIP"``
        is accepted.
    data_dir : str or Path, optional
        Override the configured data directory.

    Raises
    ------
    ValueError
        If a label is not in :data:`AVAILABLE_MODELS`.
    FileNotFoundError
        If the HDF5 file for a requested model does not exist.
    """

    def __init__(self, models, data_dir=None):
        self.data_dir = (
            Path(data_dir) if data_dir is not None
            else Path(get_data_dir())
        )
        # A bare string would otherwise be iterated character by character.
        if isinstance(models, str):
            self._models = (models,)
        else:
            self._models = tuple(models)

        # Validate eagerly so that a typo or a missing download fails at
        # construction time rather than on first data access.
        for m in self._models:
            if m not in AVAILABLE_MODELS:
                raise ValueError(
                    f"Unknown embedding model {m!r}. "
                    f"Available: {list(AVAILABLE_MODELS)}."
                )
            if not embeddings_h5_path(self.data_dir, m).exists():
                raise FileNotFoundError(
                    f"Embeddings for {m!r} not found at "
                    f"{embeddings_h5_path(self.data_dir, m)}. "
                    "Run `laion-fmri download-embeddings` first."
                )

        # Per-model open file handles, populated lazily by ``_handle``.
        self._h5s: dict[str, h5py.File] = {}
        # Decoded image names and the reverse name -> row lookup; both
        # are filled together on first access to ``image_ids``.
        self._image_ids: np.ndarray | None = None
        self._name_to_idx: dict[str, int] | None = None
        # Cached stimulus metadata table, loaded by ``_metadata``.
        self._meta: pd.DataFrame | None = None

    # ── lifecycle ──────────────────────────────────────────────

    def __enter__(self) -> "Embeddings":
        return self

    def __exit__(self, *exc) -> None:
        self.close()

    def close(self) -> None:
        """Release every open HDF5 handle."""
        for h in self._h5s.values():
            h.close()
        # Dropping the entries lets ``_handle`` reopen files on demand.
        self._h5s.clear()

    # ── shape / inventory ─────────────────────────────────────

    @property
    def models(self) -> list[str]:
        """Model labels this handle covers, in user-supplied order."""
        return list(self._models)

    @property
    def image_ids(self) -> np.ndarray:
        """Image filenames in embedding row order (shared across models)."""
        if self._image_ids is None:
            # Read the names from the first model's file only.
            h = self._handle(self._models[0])
            raw = h["image_ids"][:]
            # h5py returns bytes for variable-length strings; decode once.
            self._image_ids = np.array(
                [x.decode() if isinstance(x, bytes) else x for x in raw]
            )
            self._name_to_idx = {
                n: i for i, n in enumerate(self._image_ids)
            }
        return self._image_ids

    def __len__(self) -> int:
        """Return the number of embedding rows (one per image)."""
        return len(self.image_ids)

    def __contains__(self, model_or_pair) -> bool:
        """Test for a model label or a ``(model, image_name)`` pair.

        A string is looked up among the loaded models. A 2-tuple is true
        only when the model is loaded and the image name is known. Any
        other argument yields ``False``.
        """
        if isinstance(model_or_pair, str):
            return model_or_pair in self._models
        if isinstance(model_or_pair, tuple) and len(model_or_pair) == 2:
            model, name = model_or_pair
            if model not in self._models:
                return False
            # Touch the property so the name -> row lookup is built.
            _ = self.image_ids
            return name in (self._name_to_idx or {})
        return False

    # ── data access ───────────────────────────────────────────

    def __getitem__(self, model: str) -> h5py.Dataset:
        """Return the full ``(N, feature_dim)`` embedding dataset for ``model``.

        This is the raw h5py Dataset — read into memory with
        ``emb["CLIP"][:]`` or slice it directly to stream a subset.

        Raises
        ------
        KeyError
            If ``model`` is not covered by this handle.
        """
        self._require_model(model)
        return self._handle(model)["embedding"]

    def get(self, model: str, image_name) -> np.ndarray:
        """Return embedding row(s) for one or many image names.

        Parameters
        ----------
        model : str
            One of :data:`AVAILABLE_MODELS`.
        image_name : str or sequence of str
            One image filename or a list/array of filenames. Names may
            repeat; an empty sequence is allowed.

        Returns
        -------
        np.ndarray
            ``(feature_dim,)`` if a single name was passed, otherwise
            ``(n, feature_dim)`` in the requested order (``n`` may be 0).

        Raises
        ------
        KeyError
            If ``model`` is not covered by this handle or an image name
            is unknown.
        """
        self._require_model(model)
        dataset = self[model]
        if isinstance(image_name, str):
            return dataset[self._index_of(image_name)]

        # Force an integer dtype so that an empty request does not become
        # a float64 array, which is not a valid index.
        indices = np.array(
            [self._index_of(n) for n in image_name], dtype=np.intp
        )
        if indices.size == 0:
            return np.empty((0, *dataset.shape[1:]), dtype=dataset.dtype)

        # h5py only accepts strictly increasing index lists, so read each
        # distinct row once (``np.unique`` sorts and de-duplicates) and
        # then expand back to the caller's order, repeats included.
        unique_idx, inverse = np.unique(indices, return_inverse=True)
        unique_rows = dataset[unique_idx, :]
        return unique_rows[inverse]

    # ── internals ─────────────────────────────────────────────

    def _handle(self, model: str) -> h5py.File:
        """Return the open HDF5 file for ``model``, opening it if needed."""
        h = self._h5s.get(model)
        if h is None:
            h = h5py.File(embeddings_h5_path(self.data_dir, model), "r")
            self._h5s[model] = h
        return h

    def _require_model(self, model: str) -> None:
        """Raise ``KeyError`` unless ``model`` is covered by this handle."""
        if model not in self._models:
            raise KeyError(
                f"Model {model!r} not loaded by this handle. "
                f"Loaded: {self._models}."
            )

    def _index_of(self, name: str) -> int:
        """Return the embedding row index of ``name`` or raise ``KeyError``."""
        # Touch the property so the name -> row lookup is built.
        _ = self.image_ids
        idx = (self._name_to_idx or {}).get(name)
        if idx is None:
            raise KeyError(f"Unknown image name: {name!r}")
        return idx

    def _metadata(self) -> pd.DataFrame:
        """Return the stimulus metadata CSV as a DataFrame, read once.

        Raises
        ------
        FileNotFoundError
            If the metadata CSV is missing from the data directory.
        """
        if self._meta is None:
            csv_path = stimuli_metadata_path(self.data_dir)
            if not csv_path.exists():
                raise FileNotFoundError(
                    f"Stimulus metadata not found at {csv_path}. "
                    "Run `laion-fmri download-stimuli` first."
                )
            self._meta = pd.read_csv(csv_path)
        return self._meta
def load_embeddings(models="all", data_dir=None) -> Embeddings:
    """Return a lazy embedding reader for one or more models.

    Parameters
    ----------
    models : "all", str, or iterable[str]
        ``"all"`` loads every available embedding model. A single
        model label such as ``"CLIP"`` or an iterable of labels
        narrows the reader.
    data_dir : str or Path, optional
        Override the configured data directory.

    Returns
    -------
    Embeddings
        Reader covering the selected models.
    """
    # Only a real string can be the "all" sentinel. Comparing a NumPy
    # array or pandas Series of labels to "all" is elementwise, and the
    # result has no unambiguous truth value.
    load_everything = isinstance(models, str) and models == "all"
    selected = AVAILABLE_MODELS if load_everything else models
    return Embeddings(selected, data_dir=data_dir)