# Source code for laion_fmri.torch_data

"""PyTorch Dataset integration for laion_fmri."""

import importlib.util

import numpy as np


def _check_torch_available():
    """Raise ImportError if torch is not installed."""
    if importlib.util.find_spec("torch") is None:
        raise ImportError(
            "PyTorch is required for LaionFMRIDataset. Install it with: "
            "python -m pip install 'laion-fmri[torch] @ "
            "git+https://github.com/ViCCo-Group/LAION-fMRI.git@main'"
        )


def _resolve_rep_indices(trial_info):
    """Return a per-trial ``rep_index`` array.

    Real bucket trial TSVs ship only ``session/run/beta_index/
    label`` -- no ``rep_index`` column -- so the index is
    derived by counting prior occurrences of each stimulus
    identifier. When the column is already present (synthetic
    fixtures, future schemas), it's used verbatim.
    """
    if "rep_index" in trial_info.columns:
        return trial_info["rep_index"].to_numpy()
    if "label" in trial_info.columns:
        ids = trial_info["label"]
    elif "stimulus_id" in trial_info.columns:
        ids = trial_info["stimulus_id"]
    else:
        raise ValueError(
            "Trial info has neither 'label' nor 'stimulus_id' "
            "column; cannot derive rep_index."
        )
    seen = {}
    rep = []
    for sid in ids:
        rep.append(seen.get(sid, 0))
        seen[sid] = seen.get(sid, 0) + 1
    return np.array(rep)


class LaionFMRIDataset:
    """PyTorch Dataset wrapping one session of a LAION-fMRI subject.

    Parameters
    ----------
    subject : laion_fmri.subject.Subject
        A loaded subject object.
    session : str
        BIDS session ID. Required -- single-trial betas are stored
        per session.
    roi : str or None
        ROI name for voxel selection.
    mask : np.ndarray[bool] or None
        Custom boolean mask.
    nc_threshold : float or None
        Minimum noise ceiling for voxel inclusion.
    image_transform : callable or None
        Transform applied to image tensors.
    mask_source : ``"anatomical"`` (default) | ``"rsquare"``
        Forwarded to :meth:`Subject.get_betas`; see
        :meth:`Subject.get_brain_mask` for the difference.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    ValueError
        If the number of betas, the number of this session's rows in
        ``subject.metadata`` and the number of rep indices derived from
        ``subject.get_trial_info`` disagree. All three are indexed in
        lockstep by ``__getitem__``, so a mismatch would otherwise pair
        betas with the wrong image / name / repetition silently.

    Notes
    -----
    Each item is a dict with keys ``"betas"`` (float32 tensor),
    ``"image"`` (float32 tensor in [0, 1], channels first, before
    ``image_transform``), ``"image_name"``, ``"session"`` and
    ``"rep_index"`` (int).
    """

    def __init__(
        self,
        subject,
        session,
        roi=None,
        mask=None,
        nc_threshold=None,
        image_transform=None,
        mask_source="anatomical",
    ):
        _check_torch_available()
        # Imported lazily so the package works without the torch extra.
        import torch

        self._torch = torch
        self._subject = subject
        self._session = session
        self._image_transform = image_transform
        self._betas = subject.get_betas(
            session=session,
            roi=roi,
            mask=mask,
            nc_threshold=nc_threshold,
            mask_source=mask_source,
        )
        # Slice the subject's global trial table down to this session,
        # preserving global trial indices (used as image keys).
        trial_table = subject.metadata
        self._session_rows = (
            trial_table[trial_table["session"] == session]
            .reset_index(drop=False)
            .rename(columns={"index": "global_trial"})
        )
        self._rep_indices = _resolve_rep_indices(
            subject.get_trial_info(session=session),
        )

        n_betas = len(self._betas)
        n_rows = len(self._session_rows)
        n_reps = len(self._rep_indices)
        if not n_betas == n_rows == n_reps:
            raise ValueError(
                f"Session {session!r}: got {n_betas} betas, {n_rows} "
                f"metadata rows and {n_reps} trial-info rows; these "
                "must correspond one-to-one."
            )

    def __len__(self):
        """Return the number of trials (beta entries) in this session."""
        return len(self._betas)

    def __getitem__(self, idx):
        """Return betas, stimulus image and metadata for session trial ``idx``."""
        torch = self._torch
        betas_tensor = torch.tensor(self._betas[idx], dtype=torch.float32)

        # Look the metadata row up once; it supplies both the global
        # trial key for the image store and the image name.
        row = self._session_rows.iloc[idx]
        global_trial = int(row["global_trial"])
        img = self._subject.images.get(global_trial).convert("RGB")
        # (H, W, 3) with 0-255 values -> (3, H, W) float32 scaled to [0, 1].
        img_array = np.array(img, dtype=np.float32) / 255.0
        img_tensor = torch.tensor(img_array.transpose(2, 0, 1))
        if self._image_transform is not None:
            img_tensor = self._image_transform(img_tensor)

        return {
            "betas": betas_tensor,
            "image": img_tensor,
            "image_name": row["image_name"],
            "session": self._session,
            "rep_index": int(self._rep_indices[idx]),
        }