Source code for laion_fmri.torch_data
"""PyTorch Dataset integration for laion_fmri."""
import importlib.util
import numpy as np
from laion_fmri._paths import stimuli_dir_path
def _check_torch_available():
"""Raise ImportError if torch is not installed."""
if importlib.util.find_spec("torch") is None:
raise ImportError(
"PyTorch is required for LaionFMRIDataset. "
"Install it with: pip install laion-fmri[torch]"
)
def _resolve_rep_indices(trial_info):
"""Return a per-trial ``rep_index`` array.
Real bucket trial TSVs ship only ``session/run/beta_index/
label`` -- no ``rep_index`` column -- so the index is
derived by counting prior occurrences of each stimulus
identifier. When the column is already present (synthetic
fixtures, future schemas), it's used verbatim.
"""
if "rep_index" in trial_info.columns:
return trial_info["rep_index"].to_numpy()
if "label" in trial_info.columns:
ids = trial_info["label"]
elif "stimulus_id" in trial_info.columns:
ids = trial_info["stimulus_id"]
else:
raise ValueError(
"Trial info has neither 'label' nor 'stimulus_id' "
"column; cannot derive rep_index."
)
seen = {}
rep = []
for sid in ids:
rep.append(seen.get(sid, 0))
seen[sid] = seen.get(sid, 0) + 1
return np.array(rep)
[docs]
class LaionFMRIDataset:
    """PyTorch Dataset wrapping one session of a LAION-fMRI subject.

    Parameters
    ----------
    subject : Subject
        A loaded Subject instance.
    session : str
        BIDS session ID. Required -- single-trial betas are stored
        per session.
    roi : str or None
        ROI name for voxel selection.
    mask : np.ndarray[bool] or None
        Custom boolean mask.
    nc_threshold : float or None
        Minimum noise ceiling for voxel inclusion.
    image_transform : callable or None
        Transform applied to image tensors.

    Notes
    -----
    The class implements the map-style dataset protocol
    (``__len__`` / ``__getitem__``). The ``torch`` module is not stored
    on the instance, because module objects cannot be pickled.
    ``DataLoader`` pickles the dataset when worker processes are started
    with the ``spawn`` method.
    """

    def __init__(
        self,
        subject,
        session,
        roi=None,
        mask=None,
        nc_threshold=None,
        image_transform=None,
    ):
        _check_torch_available()
        # Import eagerly so a present-but-broken torch install fails at
        # construction time rather than on the first __getitem__ call.
        import torch  # noqa: F401
        self._subject = subject
        self._session = session
        self._image_transform = image_transform
        self._betas = subject.get_betas(
            session=session,
            roi=roi,
            mask=mask,
            nc_threshold=nc_threshold,
        )
        self._stim_meta = subject.get_stimulus_metadata()
        self._trial_info = subject.get_trial_info(session=session)
        self._stim_indices = subject.get_trial_stimulus_indices(
            session=session,
        )
        self._rep_indices = _resolve_rep_indices(self._trial_info)
        stim_dir = stimuli_dir_path(subject._data_dir)
        self._image_paths = [
            stim_dir / fn for fn in self._stim_meta["filename"]
        ]

    @property
    def _torch(self):
        """The ``torch`` module, imported on access.

        A property rather than an instance attribute, so the dataset
        remains picklable (module objects are not).
        """
        import torch
        return torch

    def __len__(self):
        """Return the number of trials (rows of the betas array)."""
        return len(self._betas)

    def __getitem__(self, idx):
        """Return trial ``idx`` as a dict.

        Keys: ``betas`` (float32 tensor of the trial's betas), ``image``
        (the RGB stimulus as a float tensor scaled to [0, 1], channels
        first, after ``image_transform`` if one was given),
        ``stimulus_id``, ``session`` and ``rep_index``.
        """
        from PIL import Image
        torch = self._torch
        betas_tensor = torch.tensor(
            self._betas[idx], dtype=torch.float32,
        )
        stim_idx = int(self._stim_indices[idx])
        # The context manager releases the file handle deterministically.
        # convert() returns an independent image, so it stays valid
        # after the source image is closed.
        with Image.open(self._image_paths[stim_idx]) as src:
            img = src.convert("RGB")
        img_array = np.array(img, dtype=np.float32) / 255.0
        # HWC -> CHW, the channel order torch vision models expect.
        img_tensor = torch.tensor(img_array.transpose(2, 0, 1))
        if self._image_transform is not None:
            img_tensor = self._image_transform(img_tensor)
        return {
            "betas": betas_tensor,
            "image": img_tensor,
            # Select the column first. Materialising the whole row would
            # build a Series per item and can upcast the ID's dtype.
            "stimulus_id": self._stim_meta["stimulus_id"].iloc[stim_idx],
            "session": self._session,
            "rep_index": int(self._rep_indices[idx]),
        }