laion_fmri.splits

Predefined train / test splits for the re:vision generalization framework.

LAION-fMRI ships with 12 train/test splits per pool. Each split tests generalization either within a per-subject pool (shared + unique stimuli) or within the cross-subject shared pool alone.

Pools

  • "shared" — the 1,121 LAION images shared across every subject (non-OOD subset of the shared block).

  • "sub-01", "sub-03", "sub-05", "sub-06", "sub-07" — the 5,833-image per-subject pools (1,121 shared + 4,712 unique).

Splits

The same 12 split names exist in every pool:

  • random_0 … random_4 — five seeded uniform-random partitions (re:vision baseline).

  • cluster_k5_0 … cluster_k5_4 — five hold-out-cluster partitions (CLIP-feature k-means; one cluster held out as test). re:vision Method 2.

  • tau — the MMD-matched 80/20 nearest-neighbour-distance split. re:vision Method 1.

  • ood — train = the pool’s regular images, test = all 371 OOD shared images. re:vision Method 3. Test set is identical across pools; train varies with pool.
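The twelve names follow directly from the four families above. A minimal sketch that rebuilds them by hand, assuming the numbered names follow the family_index pattern shown in the list. The order returned by list_splits() is not stated here, so this ordering is only illustrative:

```python
# Rebuild the split names from the four families listed above.
# The ordering is illustrative; list_splits() may order them differently.
random_names = [f"random_{i}" for i in range(5)]
cluster_names = [f"cluster_k5_{i}" for i in range(5)]
split_names = random_names + cluster_names + ["tau", "ood"]

print(len(split_names))  # 12
```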

Loading

The natural entry point is get_train_test_ids() (or get_split_masks() for trial-table masks):

>>> from laion_fmri.splits import get_train_test_ids
>>> train_ids, test_ids = get_train_test_ids("tau", pool="shared")
>>> len(train_ids), len(test_ids)
(897, 224)
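These counts are consistent with an 80/20 partition of the 1,121 shared images. The arithmetic below is only a sanity check: rounding to the nearest integer is an assumption that happens to reproduce the numbers, not a statement about how the split was built.

```python
n_shared = 1121                    # size of the "shared" pool
n_train = round(0.8 * n_shared)    # 896.8 rounds to 897
n_test = n_shared - n_train        # the remaining images

print(n_train, n_test)  # 897 224
```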

OOD type filter

The ood split’s test set spans 9 OOD categories. Pass ood_types= to restrict the test set to a subset:

>>> from laion_fmri.splits import list_ood_types
>>> list_ood_types()
['cropped', 'gabor', 'gaudy', 'illusion-classic', 'illusion-natural', 'relations', 'selfmade', 'shape', 'unusual']
>>> _, test_shape = get_train_test_ids(
...     "ood", pool="shared", ood_types=["shape"],
... )
>>> len(test_shape)
82
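Conceptually, the filter keeps only the test images whose OOD category was requested. The sketch below illustrates that idea on toy data. The helper filter_by_ood_type, the image ids, and the id-to-category mapping are all hypothetical and are not part of laion_fmri.splits:

```python
from typing import Dict, Iterable, List, Union


def filter_by_ood_type(
    test_ids: List[str],
    id_to_type: Dict[str, str],
    ood_types: Union[str, Iterable[str], None] = None,
) -> List[str]:
    """Keep only the test ids whose OOD category is in ood_types."""
    if ood_types is None:
        return list(test_ids)      # no filter: keep every category
    if isinstance(ood_types, str):
        ood_types = [ood_types]    # a bare string means one category
    wanted = set(ood_types)
    return [i for i in test_ids if id_to_type[i] in wanted]


# Toy data: the ids and the id -> category mapping are made up.
toy_test_ids = ["img_a", "img_b", "img_c", "img_d"]
toy_types = {
    "img_a": "shape", "img_b": "gabor", "img_c": "shape", "img_d": "gaudy",
}

print(filter_by_ood_type(toy_test_ids, toy_types, "shape"))
# ['img_a', 'img_c']
```

Accepting either a single string or an iterable mirrors the ood_types type annotation in the function signatures.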

The returned image_ids match the label column of every session’s events TSV. Slice betas with the existing get_betas filter system — see Load for details.

Functions

get_split_masks(trials, name, pool[, ...])

Build (train_mask, test_mask) over rows of a trial table.

get_train_test_ids(name, pool[, variant_id, ...])

Convenience: return (train_ids, test_ids) for one variant.

list_ood_types()

Return the 9 OOD categories present in the ood split.

list_pools()

Return every pool that has bundled splits.

list_splits()

Return the 12 split names available in every pool.

load_all_splits(pool)

Load every split for pool as {name: Split}.

load_split(name, pool)

Load one bundled split.

laion_fmri.splits.get_split_masks(trials, name: str, pool: str, variant_id: int = 0, ood_types: str | Iterable[str] | None = None) → Tuple[ndarray, ndarray]

Build (train_mask, test_mask) over rows of a trial table.

The masks are derived by matching trials["label"] against the train/test image-id lists of the requested split. They line up one-to-one with the trials passed in, so you can apply them to any trial-aligned array (betas, features, decoded labels, …).

Parameters:
  • trials (pandas.DataFrame, pandas.Series, np.ndarray or list) – Trial-level labels. If a DataFrame, the "label" column is used; otherwise the input is treated as label values directly.

  • name (str) – One of the 12 split names (see list_splits()).

  • pool (str) – "shared" or a subject id like "sub-01" (see list_pools()).

  • variant_id (int, default 0) – Variant within the split. Almost always 0.

  • ood_types (str, list[str], or None) – Only meaningful when name == "ood". See get_train_test_ids().

Returns:

(train_mask, test_mask) – Boolean arrays, both of length len(trials).

Return type:

tuple of np.ndarray

Examples

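>>> import pandas as pd
>>> import laion_fmri
>>> from laion_fmri.splits import get_split_masks
>>> sub = laion_fmri.load_subject("sub-01")
>>> trials = pd.concat(
...     sub.get_trial_info(session=sub.get_sessions()).values(),
...     ignore_index=True,
... )
>>> train_mask, test_mask = get_split_masks(
...     trials, "tau", pool="shared",
... )
>>> # betas: any trial-aligned array loaded separately (e.g. via get_betas)
>>> betas[train_mask], betas[test_mask]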
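The matching step described above amounts to a membership test of each trial label against the train and test id lists. A minimal pure-Python sketch of that idea follows. The helper masks_from_labels and the toy labels are hypothetical, and the real function returns NumPy boolean arrays rather than lists:

```python
from typing import Iterable, List, Tuple


def masks_from_labels(
    labels: Iterable[str],
    train_ids: Iterable[str],
    test_ids: Iterable[str],
) -> Tuple[List[bool], List[bool]]:
    """Return one train flag and one test flag per trial label."""
    train_set, test_set = set(train_ids), set(test_ids)
    labels = list(labels)
    train_mask = [label in train_set for label in labels]
    test_mask = [label in test_set for label in labels]
    return train_mask, test_mask


# Toy trial labels; repeated presentations of an image share a label.
toy_labels = ["img_1", "img_2", "img_1", "img_9", "img_3"]
train_mask, test_mask = masks_from_labels(
    toy_labels, train_ids=["img_1", "img_3"], test_ids=["img_2"],
)
print(train_mask)  # [True, False, True, False, True]
print(test_mask)   # [False, True, False, False, False]
```

In this sketch a label that appears in neither list ("img_9") is False in both masks. Whether the library treats such trials the same way is not stated here.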

laion_fmri.splits.get_train_test_ids(name: str, pool: str, variant_id: int = 0, ood_types: str | Iterable[str] | None = None) → Tuple[List[str], List[str]]

Convenience: return (train_ids, test_ids) for one variant.

Most splits have a single variant, variant_id=0. For the random_* and cluster_k5_* families, each of the five numbered split names is itself a variant: choose the partition by split name and leave variant_id at 0.

Parameters:

ood_types (str, list[str], or None) – Only meaningful when name == "ood". Restricts the test set to image_ids of these OOD categories (see list_ood_types()). None keeps all 9 categories.

laion_fmri.splits.list_ood_types() → List[str]

Return the 9 OOD categories present in the ood split.

laion_fmri.splits.list_pools() → List[str]

Return every pool that has bundled splits.

laion_fmri.splits.list_splits() → List[str]

Return the 12 split names available in every pool.

laion_fmri.splits.load_all_splits(pool: str) → Dict[str, Split]

Load every split for pool as {name: Split}.

laion_fmri.splits.load_split(name: str, pool: str) → Split

Load one bundled split.

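Parameters:
  • name (str) – One of the 12 split names (see list_splits()).

  • pool (str) – "shared" or a subject id like "sub-01" (see list_pools()).

Return type: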

Split

Classes

Split(name, pool, splitter, params, n_train, ...)

One named split (e.g. "tau") for one pool (e.g. "sub-01").

SplitVariant(variant_id, train_ids, test_ids)

A single train/test image-id partition within a Split.

class laion_fmri.splits.Split(name: str, pool: str, splitter: str, params: Dict, n_train: int, n_test: int, variants: List[SplitVariant] = <factory>)

Bases: object

One named split (e.g. "tau") for one pool (e.g. "sub-01").

property split_family: str

Coarse family.

One of "random", "cluster_k5", "tau", or "ood".
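One way to picture the mapping from split name to family is to drop the trailing index from the numbered names. The helper family_of below is a hypothetical illustration and not necessarily how the property is implemented:

```python
import re


def family_of(split_name: str) -> str:
    """Drop a trailing _<index> so numbered splits map to their family."""
    return re.sub(r"_\d+$", "", split_name)


for name in ["random_3", "cluster_k5_0", "tau", "ood"]:
    print(name, "->", family_of(name))
```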

class laion_fmri.splits.SplitVariant(variant_id: int, train_ids: List[str], test_ids: List[str])

Bases: object

A single train/test image-id partition within a Split.
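The signatures above, including the factory default on variants, suggest dataclass-style containers. The sketch below mirrors those fields with stand-in classes (SplitSketch and SplitVariantSketch are hypothetical names) and adds a small consistency check. It assumes that n_train and n_test describe the length of each variant's id lists, which the documentation does not state explicitly:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SplitVariantSketch:
    variant_id: int
    train_ids: List[str]
    test_ids: List[str]


@dataclass
class SplitSketch:
    name: str
    pool: str
    splitter: str
    params: Dict
    n_train: int
    n_test: int
    variants: List[SplitVariantSketch] = field(default_factory=list)


def check_variant(split: SplitSketch, variant: SplitVariantSketch) -> bool:
    """True if train and test are disjoint and match the declared sizes."""
    disjoint = not set(variant.train_ids) & set(variant.test_ids)
    sizes_ok = (
        len(variant.train_ids) == split.n_train
        and len(variant.test_ids) == split.n_test
    )
    return disjoint and sizes_ok


# Toy values only; the splitter name and params are placeholders.
toy_variant = SplitVariantSketch(0, ["img_1", "img_2"], ["img_3"])
toy_split = SplitSketch(
    "tau", "shared", "toy_splitter", {}, 2, 1, [toy_variant],
)
print(check_variant(toy_split, toy_variant))  # True
```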