Source code for laion_fmri.splits

"""Predefined train / test splits for the re:vision generalization framework.

LAION-fMRI ships with **12 train/test splits per pool**, designed to test
generalization across the per-subject ``shared + unique`` stimulus pool
or across just the cross-subject shared pool.

Pools
-----

* ``"shared"`` — the 1,121 LAION images shared across every subject
  (non-OOD subset of the shared block).
* ``"sub-01"``, ``"sub-03"``, ``"sub-05"``, ``"sub-06"``, ``"sub-07"`` —
  the 5,833-image per-subject pools (1,121 shared + 4,712 unique).

Splits
------

The same 12 split names exist in every pool:

* ``random_0`` … ``random_4`` — five seeded uniform-random partitions
  (re:vision *baseline*).
* ``cluster_k5_0`` … ``cluster_k5_4`` — five hold-out-cluster partitions
  (CLIP-feature k-means; one cluster held out as test). re:vision
  *Method 2*.
* ``tau`` — the MMD-matched 80/20 nearest-neighbour-distance split.
  re:vision *Method 1*.
* ``ood`` — train = the pool's regular images, test = all 371 OOD
  shared images. re:vision *Method 3*. Test set is identical across
  pools; train varies with pool.

Loading
-------

The natural entry point is :func:`get_train_test_ids` (or
:func:`get_split_masks` for trial-table masks):

>>> from laion_fmri.splits import get_train_test_ids
>>> train_ids, test_ids = get_train_test_ids("tau", pool="shared")
>>> len(train_ids), len(test_ids)
(897, 224)

OOD type filter
---------------

The ``ood`` split's test set spans 9 OOD categories. Pass
``ood_types=`` to restrict the test set to a subset:

>>> from laion_fmri.splits import list_ood_types
>>> list_ood_types()
['cropped', 'gabor', 'gaudy', 'illusion-classic', 'illusion-natural', \
'relations', 'selfmade', 'shape', 'unusual']
>>> _, test_shape = get_train_test_ids(
...     "ood", pool="shared", ood_types=["shape"],
... )
>>> len(test_shape)
82

The returned image_ids match the ``label`` column of every session's
events TSV. Slice betas with the existing ``get_betas`` filter system —
see :doc:`/laion_fmri_package/load` for details.
"""

from __future__ import annotations

import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

# Per-subject pools follow the BIDS subject IDs already used elsewhere in
# the package (sub-01, sub-03, sub-05, sub-06, sub-07). The "shared" pool
# is cross-subject. No internal participant-code abstraction is exposed.
_SUBJECT_POOLS: Tuple[str, ...] = (
    "sub-01", "sub-03", "sub-05", "sub-06", "sub-07",
)
_SHARED_POOL = "shared"

# All split names available in every pool. 12 names × 6 pools = 72 JSONs.
_SPLIT_NAMES: Tuple[str, ...] = (
    "random_0", "random_1", "random_2", "random_3", "random_4",
    "cluster_k5_0", "cluster_k5_1", "cluster_k5_2",
    "cluster_k5_3", "cluster_k5_4",
    "tau",
    "ood",
)


# OOD categories present in the bundled `ood` split. Same 371 images
# across all pools; only the train side differs by pool. Categories
# come from the filename pattern shared_4rep_OOD_<type>_<...>.{png,jpg}.
_OOD_TYPES: Tuple[str, ...] = (
    "cropped",
    "gabor",
    "gaudy",
    "illusion-classic",
    "illusion-natural",
    "relations",
    "selfmade",
    "shape",
    "unusual",
)
_OOD_FILENAME_RE = re.compile(
    r"^shared_4rep_OOD_([a-zA-Z0-9-]+)_.+\.(png|jpg)$"
)


[docs] def list_ood_types() -> List[str]: """Return the 9 OOD categories present in the ``ood`` split.""" return list(_OOD_TYPES)
def _ood_type_for_image(image_id: str) -> Optional[str]: """Return the OOD category of an image_id, or None if not an OOD image.""" m = _OOD_FILENAME_RE.match(image_id) return m.group(1) if m else None def _validate_ood_types( types: Optional[Union[str, Iterable[str]]], ) -> Optional[Tuple[str, ...]]: """Coerce an ``ood_types`` argument to a tuple of valid type names.""" if types is None: return None if isinstance(types, str): types = [types] out = tuple(types) bad = [t for t in out if t not in _OOD_TYPES] if bad: raise ValueError( f"Unknown ood_types {bad!r}. Valid: {list(_OOD_TYPES)}" ) return out
[docs] @dataclass class SplitVariant: """A single train/test image-id partition within a Split.""" variant_id: int train_ids: List[str] test_ids: List[str] def all_ids(self) -> List[str]: return list(self.train_ids) + list(self.test_ids)
[docs] @dataclass class Split: """One named split (e.g. ``"tau"``) for one pool (e.g. ``"sub-01"``).""" name: str pool: str splitter: str params: Dict n_train: int n_test: int variants: List[SplitVariant] = field(default_factory=list) @property def split_family(self) -> str: """Coarse family. One of ``"random"``, ``"cluster_k5"``, ``"tau"``, or ``"ood"``. """ if self.name.startswith("random_"): return "random" if self.name.startswith("cluster_k5_"): return "cluster_k5" if self.name == "tau": return "tau" if self.name == "ood": return "ood" return "other"
# ── Locator ──────────────────────────────────────────────────── def _data_dir() -> Path: """Path to the bundled split JSONs (under the installed package).""" return Path(__file__).resolve().parent / "data" def _validate_pool(pool: str) -> str: if pool == _SHARED_POOL or pool in _SUBJECT_POOLS: return pool raise ValueError( f"Unknown pool {pool!r}. Valid: " f"{[_SHARED_POOL] + list(_SUBJECT_POOLS)}." ) def _validate_split_name(name: str) -> str: if name not in _SPLIT_NAMES: raise ValueError( f"Unknown split {name!r}. Valid: {list(_SPLIT_NAMES)}." ) return name # ── Public API ─────────────────────────────────────────────────
[docs] def list_pools() -> List[str]: """Return every pool that has bundled splits.""" return [_SHARED_POOL] + list(_SUBJECT_POOLS)
[docs] def list_splits() -> List[str]: """Return the 11 split names available in every pool.""" return list(_SPLIT_NAMES)
[docs] def load_split(name: str, pool: str) -> Split: """Load one bundled split. Parameters ---------- name : str One of the 11 split names (see :func:`list_splits`). pool : str ``"shared"`` or a subject id like ``"sub-01"`` (see :func:`list_pools`). Returns ------- Split """ _validate_split_name(name) _validate_pool(pool) path = _data_dir() / pool / f"{name}.json" if not path.is_file(): raise FileNotFoundError( f"Split {name!r} for pool {pool!r} not found at {path}. " f"Available pools: {list_pools()}." ) raw = json.loads(path.read_text()) variants = [ SplitVariant( variant_id=int(v["variant_id"]), train_ids=list(v["train_ids"]), test_ids=list(v["test_ids"]), ) for v in raw["variants"] ] return Split( name=raw["name"], pool=pool, splitter=raw.get("splitter", ""), params=raw.get("params", {}), n_train=int(raw["n_train"]), n_test=int(raw["n_test"]), variants=variants, )
[docs] def load_all_splits(pool: str) -> Dict[str, Split]: """Load every split for ``pool`` → ``{name: Split}``.""" return {name: load_split(name, pool) for name in _SPLIT_NAMES}
[docs] def get_train_test_ids( name: str, pool: str, variant_id: int = 0, ood_types: Optional[Union[str, Iterable[str]]] = None, ) -> Tuple[List[str], List[str]]: """Convenience: return ``(train_ids, test_ids)`` for one variant. Most splits have a single ``variant_id=0``. The five ``random_*`` and the five ``cluster_k5_*`` splits each ARE the variants — pick the split name; ``variant_id`` stays 0. Parameters ---------- ood_types : str, list[str], or None Only meaningful when ``name == "ood"``. Restricts the test set to image_ids of these OOD categories (see :func:`list_ood_types`). ``None`` keeps all 9 categories. """ sp = load_split(name, pool) for v in sp.variants: if v.variant_id == variant_id: train, test = list(v.train_ids), list(v.test_ids) break else: raise ValueError( f"variant_id={variant_id} not in split {name!r} (pool {pool}). " f"Available variants: {[v.variant_id for v in sp.variants]}" ) types = _validate_ood_types(ood_types) if types is not None: if name != "ood": raise ValueError( "`ood_types` is only valid for the 'ood' split, " f"got name={name!r}." ) keep = set(types) test = [iid for iid in test if _ood_type_for_image(iid) in keep] return train, test
[docs] def get_split_masks( trials, name: str, pool: str, variant_id: int = 0, ood_types: Optional[Union[str, Iterable[str]]] = None, ) -> Tuple[np.ndarray, np.ndarray]: """Build ``(train_mask, test_mask)`` over rows of a trial table. The masks are derived by matching ``trials["label"]`` against the train/test image-id lists of the requested split. They line up one-to-one with the trials passed in, so you can apply them to *any* trial-aligned array (betas, features, decoded labels, …). Parameters ---------- trials : pandas.DataFrame, pandas.Series, np.ndarray or list Trial-level labels. If a ``DataFrame``, the ``"label"`` column is used; otherwise the input is treated as label values directly. name : str One of the 12 split names (see :func:`list_splits`). pool : str ``"shared"`` or a subject id like ``"sub-01"`` (see :func:`list_pools`). variant_id : int, default 0 Variant within the split. Almost always 0. ood_types : str, list[str], or None Only meaningful when ``name == "ood"``. See :func:`get_train_test_ids`. Returns ------- (train_mask, test_mask) : tuple of np.ndarray Boolean arrays, both of length ``len(trials)``. Examples -------- >>> sub = laion_fmri.load_subject("sub-01") >>> trials = pd.concat( ... sub.get_trial_info(session=sub.get_sessions()).values(), ... ignore_index=True, ... ) >>> train_mask, test_mask = get_split_masks( ... trials, "tau", pool="shared", ... ) >>> betas[train_mask], betas[test_mask] # doctest: +SKIP """ train_ids, test_ids = get_train_test_ids( name, pool, variant_id=variant_id, ood_types=ood_types, ) train_set = set(train_ids) test_set = set(test_ids) if ( hasattr(trials, "columns") and "label" in getattr(trials, "columns", ()) ): labels = trials["label"].to_numpy() elif hasattr(trials, "to_numpy"): labels = trials.to_numpy() else: labels = np.asarray(trials) train_mask = np.array([lab in train_set for lab in labels], dtype=bool) test_mask = np.array([lab in test_set for lab in labels], dtype=bool) return train_mask, test_mask
__all__ = [ "Split", "SplitVariant", "list_pools", "list_splits", "list_ood_types", "load_split", "load_all_splits", "get_train_test_ids", "get_split_masks", ]