Source code for swc.aeon.io.reader

"""Module for reading data from raw files in an Aeon dataset."""

import datetime
import json
import os
from pathlib import Path
from typing import Any

import harp
import numpy as np
import pandas as pd
from dotmap import DotMap
from pandas._typing import DtypeArg, SequenceNotStr

from swc.aeon.io.api import Reader, chunk_key


class Harp(Reader):
    """Reader for raw binary files encoded with the Harp protocol."""

    def __init__(self, pattern: str, columns: SequenceNotStr[str], extension="bin"):
        """Configure the reader.

        Args:
            pattern: File-name pattern, forwarded unchanged to ``Reader``.
            columns: Column labels handed to ``harp.read`` when decoding a file.
            extension: File extension, forwarded to ``Reader``. Defaults to "bin".
        """
        super().__init__(pattern, columns, extension)

    def read(self, path: Path) -> pd.DataFrame:
        """Decode a single Harp binary file.

        Args:
            path: Location of the Harp binary file.

        Returns:
            The DataFrame produced by ``harp.read`` using this reader's column labels.
        """
        labels = self.columns
        decoded = harp.read(path, columns=labels)
        return decoded
class Chunk(Reader):
    """Reader that reports the path and epoch of each chunk file in the dataset."""

    def __init__(
        self, reader: Reader | None = None, pattern: str | None = None, extension: str | None = None
    ):
        """Configure the reader from another reader, or from an explicit pattern and extension.

        Args:
            reader: Optional reader whose ``pattern`` and ``extension`` are reused; when given,
                it takes precedence over ``pattern`` and ``extension``.
            pattern: File-name pattern, required when ``reader`` is not a ``Reader``.
            extension: File extension, required when ``reader`` is not a ``Reader``.

        Raises:
            ValueError: If no ``reader`` is given and either ``pattern`` or ``extension`` is None.
        """
        if not isinstance(reader, Reader):
            if pattern is None or extension is None:
                raise ValueError("reader must be specified if pattern or extension are None.")
        else:
            pattern, extension = reader.pattern, reader.extension
        super().__init__(pattern, columns=("path", "epoch"), extension=extension)

    def read(self, path: Path) -> pd.DataFrame:
        """Build a one-row table describing the chunk that ``path`` belongs to.

        Args:
            path: Path to the data file.

        Returns:
            A DataFrame indexed by the chunk value from ``chunk_key`` with "path" and "epoch" columns.
        """
        epoch, chunk = chunk_key(path)
        return pd.DataFrame(
            {"path": path, "epoch": epoch},
            index=pd.Series(chunk),
            columns=self.columns,
        )
class Metadata(Reader):
    """Reader for the per-epoch metadata files of the dataset."""

    def __init__(self, pattern="Metadata"):
        """Configure the reader with a file-name pattern (defaults to "Metadata") and ".yml" files."""
        super().__init__(pattern, columns=("workflow", "commit", "metadata"), extension="yml")

    def read(self, path: Path) -> pd.DataFrame:
        """Load the metadata of one epoch.

        The epoch start time is parsed from the name of the file's parent folder, which is
        expected to look like ``<date>T<hh-mm-ss>``. The file contents are parsed with
        ``json.load``.

        Args:
            path: Path to the metadata file.

        Returns:
            A one-row DataFrame indexed by the epoch start time, with the "Workflow" entry,
            the optional "Commit" entry (``pd.NA`` if absent), and the remaining entries
            wrapped in a ``DotMap``.
        """
        date_part, time_part = path.parts[-2].split("T")
        iso_stamp = date_part + "T" + time_part.replace("-", ":")
        start_time = datetime.datetime.fromisoformat(iso_stamp)
        with open(path) as stream:
            fields = json.load(stream)
        workflow = fields.pop("Workflow")
        commit = fields.pop("Commit", pd.NA)
        row = {"workflow": workflow, "commit": commit, "metadata": [DotMap(fields)]}
        return pd.DataFrame(row, index=pd.Series(start_time), columns=self.columns)
class Csv(Reader):
    """Reader for comma-separated (CSV) text files.

    The first column of each file holds the Aeon timestamp, in seconds.
    """

    def __init__(
        self, pattern: str, columns: SequenceNotStr[str], dtype: DtypeArg | None = None, extension="csv"
    ):
        """Configure the reader.

        Args:
            pattern: File-name pattern, forwarded to ``Reader``.
            columns: Column labels used as the ``names`` passed to ``pd.read_csv``.
            dtype: Optional dtype specification forwarded to ``pd.read_csv``.
            extension: File extension, forwarded to ``Reader``. Defaults to "csv".
        """
        super().__init__(pattern, columns, extension)
        self.dtype = dtype
        self._names = tuple(columns)

    def read(self, path: Path) -> pd.DataFrame:
        """Parse one CSV text file.

        For a non-empty file, the first column (the Aeon timestamp) becomes the index. For a
        zero-byte file no index column is requested, so pandas falls back to a RangeIndex.

        Args:
            path: Path to the CSV text file.

        Returns:
            The parsed DataFrame.
        """
        has_content = path.stat().st_size > 0
        return pd.read_csv(
            path,
            header=0,
            names=self._names,
            dtype=self.dtype,
            index_col=0 if has_content else None,
        )
class JsonList(Reader):
    """Reader for JSON-lines (.jsonl) files whose "seconds" key holds the Aeon timestamp."""

    columns: SequenceNotStr[str]
    """
    Column labels to extract from the dictionary stored in the `root_key` of the JSON object.
    Each column name must correspond to a key in the dictionary stored in the `root_key`.
    Defaults to an empty tuple, i.e. the JSON objects are read as-is.
    """

    root_key: str
    """The key in the JSON object that contains the data. Defaults to "value"."""

    def __init__(
        self,
        pattern: str,
        columns: SequenceNotStr[str] = (),
        root_key: str = "value",
        extension: str = "jsonl",
    ):
        """Configure the reader with a file pattern, the keys to extract, and the key holding them."""
        super().__init__(pattern, columns, extension)
        self.columns = columns
        self.root_key = root_key

    def read(self, path: Path) -> pd.DataFrame:
        """Parse one JSON-lines file.

        Args:
            path: Path to the jsonl file. Every record must have a "seconds" key. When
                `columns` is non-empty, `root_key` must map to a dictionary containing each
                of those column names.

        Returns:
            A DataFrame indexed by "seconds" that keeps every top-level key as a column, plus
            one extra column per entry of `columns` pulled out of the `root_key` dictionary.
        """
        with open(path) as handle:
            frame = pd.read_json(handle, lines=True)
        frame.set_index("seconds", inplace=True)
        for name in self.columns:
            # Binding ``name`` as a default argument pins the key for this iteration.
            frame[name] = frame[self.root_key].apply(lambda entry, key=name: entry[key])
        return frame
class Subject(Csv):
    """Reader for subjects entering and exiting the environment.

    Columns:
        - id (str): Unique identifier of a subject in the environment.
        - weight (float): Weight measurement of the subject on entering
          or exiting the environment.
        - event (str): Event type. Can be one of `Enter`, `Exit` or `Remain`.
    """

    def __init__(self, pattern: str):
        """Configure the reader with a file pattern and the fixed subject columns."""
        subject_columns = ("id", "weight", "event")
        super().__init__(pattern, subject_columns)
class Log(Csv):
    """Reader for message log data.

    Columns:
        - priority (str): Priority level of the message.
        - type (str): Type of the log message.
        - message (str): Log message data. Can be structured using tab
          separated values.
    """

    def __init__(self, pattern: str):
        """Configure the reader with a file pattern and the fixed log columns."""
        log_columns = ("priority", "type", "message")
        super().__init__(pattern, log_columns)
class Heartbeat(Harp):
    """Reader for periodic heartbeat events.

    Columns:
        - second (int): The whole second corresponding to the heartbeat, in seconds.
    """

    def __init__(self, pattern: str):
        """Configure the reader with a file pattern and the single "second" column."""
        heartbeat_columns = ("second",)
        super().__init__(pattern, heartbeat_columns)
class Encoder(Harp):
    """Reader for magnetic encoder data.

    Columns:
        - angle (float): Absolute angular position, in radians, of the magnetic encoder.
        - intensity (float): Intensity of the magnetic field.
    """

    def __init__(self, pattern: str):
        """Configure the reader with a file pattern and the fixed encoder columns."""
        encoder_columns = ("angle", "intensity")
        super().__init__(pattern, encoder_columns)
class Position(Harp):
    """Reader for 2D position tracking data from a specific camera.

    Columns:
        - x (float): x-coordinate of the object center of mass.
        - y (float): y-coordinate of the object center of mass.
        - angle (float): angle, in radians, of the ellipse fit to the object.
        - major (float): length, in pixels, of the major axis of the ellipse fit to the object.
        - minor (float): length, in pixels, of the minor axis of the ellipse fit to the object.
        - area (float): number of pixels in the object mass.
        - id (float): unique tracking ID of the object in a frame.
    """

    def __init__(self, pattern: str):
        """Configure the reader with a file pattern and the fixed tracking columns."""
        tracking_columns = ("x", "y", "angle", "major", "minor", "area", "id")
        super().__init__(pattern, tracking_columns)
class BitmaskEvent(Harp):
    """Reader for events whose digital I/O value contains a specific bitmask.

    Columns:
        - event (str): Unique identifier for the event code.
    """

    value: int
    """The unique event code to match against the digital I/O data."""

    tag: str
    """A tag/label to assign to the event code for identification."""

    def __init__(self, pattern: str, value: int, tag: str):
        """Configure the reader with a file pattern, the event code to match, and its label."""
        super().__init__(pattern, columns=("event",))
        self.value = value
        self.tag = tag

    def read(self, path: Path) -> pd.DataFrame:
        """Extract the samples that carry this reader's event code.

        A sample matches when every bit set in ``value`` is also set in its "event" value.

        Args:
            path: Path to the Harp binary file.

        Returns:
            The matching rows, with the "event" column replaced by ``tag``.
        """
        raw = super().read(path)
        has_code = (raw["event"] & self.value) == self.value
        # ``assign`` writes the label into a new frame instead of into a filtered slice.
        return raw.loc[has_code].assign(event=self.tag)
class DigitalBitmask(Harp):
    """Extracts state changes of digital I/O data under a specific bitmask.

    Columns:
        One boolean column per label in ``columns``: True when any bit of ``mask`` is set
        in that column's value, False otherwise. A row is emitted for the first sample and
        for every sample where the masked value of at least one column changes.
    """

    mask: int
    """The bitmask to match against changes in the digital I/O data."""

    def __init__(self, pattern: str, mask: int, columns: SequenceNotStr[str]):
        """Initialize the object with specified pattern, mask, and columns."""
        super().__init__(pattern, columns)
        self.mask = mask

    def read(self, path: Path) -> pd.DataFrame:
        """Reads the masked digital state and keeps only the samples where it changes.

        Args:
            path: Path to the Harp binary file.

        Returns:
            A boolean DataFrame with one column per configured label. The first sample is
            always included, because ``diff`` yields NaN there and NaN compares unequal to 0.
        """
        data = super().read(path)
        # Convert to a list: pandas treats a tuple passed to ``[]`` as a single column key.
        state = data[list(self.columns)] & self.mask
        # Reduce per-column changes to one flag per row. Indexing with a 2-D boolean array
        # would emit a row once per column that changed.
        changed = (state.diff() != 0).any(axis=1)
        # Positional mask, so duplicate timestamps in the index are handled safely.
        return state[changed.to_numpy()] != 0
class Video(Csv):
    """Reader for video frame metadata.

    Columns:
        - hw_counter (int): Hardware frame counter value for the current frame.
        - hw_timestamp (int): Internal camera timestamp for the current frame.
        - _frame (int): Frame index in the video file.
        - _path (str): Path to the video file.
        - _epoch (str): Epoch name associated with the video file.
    """

    def __init__(self, pattern: str):
        """Configure the reader; the raw CSV holds "time" followed by the two hardware columns."""
        super().__init__(pattern, columns=("hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"))
        self._rawcolumns = ("time", *self.columns[:2])

    def read(self, path: Path) -> pd.DataFrame:
        """Load the frame metadata stored alongside a video.

        Args:
            path: Path to the video metadata CSV file.

        Returns:
            A DataFrame indexed by "time" with the hardware columns, the row position as
            "_frame", the sibling ".avi" path as "_path", and the third-from-last path
            component as "_epoch".
        """
        frames = pd.read_csv(path, header=0, names=self._rawcolumns)
        video_file = os.path.splitext(path)[0] + ".avi"
        frames["_frame"] = frames.index
        frames["_path"] = video_file
        frames["_epoch"] = path.parts[-3]
        return frames.set_index("time")
class Pose(Harp):
    """Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods.

    The output is in long format: one row per sample and body part. All columns are stored
    with object dtype.

    Columns:
        - identity: Subject identity. Integer class indices are replaced by the class names
          from the model config when that config defines classes.
        - identity_likelihood: Likelihood of the subject's identity. For Bonsai.Sleap0.3 files
          this is a dict mapping each class name to its likelihood.
        - part (str): Bodypart on the subject.
        - x: X-coordinate of the bodypart.
        - y: Y-coordinate of the bodypart.
        - part_likelihood: Likelihood of the specified bodypart.
        - model (Path): Model directory; only present when ``read`` is called with
          ``include_model=True``.
    """

    def __init__(self, pattern: str, model_root: str | None = None):
        """Pose reader constructor.

        The pattern for this reader should typically be `<device>_<hpcnode>_<jobid>*`.
        If a register prefix is required, the pattern should end with a trailing
        underscore, e.g. `Camera_202_*`. Otherwise, the pattern should include a
        common prefix for the pose model folder excluding the trailing underscore,
        e.g. `Camera_model-dir*`.

        Args:
            pattern: File-name pattern of the pose data files.
            model_root: Optional shared directory that is searched for the model directory
                when it is not found next to the data file.
        """
        super().__init__(pattern, columns=())
        self._model_root = model_root
        # Position just after the last "_" in the pattern (0 if there is none). The file stem
        # from this offset onwards encodes the model path.
        self._pattern_offset = pattern.rfind("_") + 1

    def read(self, path: Path, include_model: bool = False) -> pd.DataFrame:
        """Reads data from the Harp-binarized tracking file.

        Args:
            path: Path to the Harp binary file.
            include_model: Specifies whether to add a "model" column holding the model
                directory derived from the file name.

        Returns:
            A DataFrame with columns "identity", "identity_likelihood", "part", "x", "y" and
            "part_likelihood" (plus "model" if requested). Row ``k * n_parts + i`` holds body
            part ``i`` of input sample ``k``.

        Raises:
            FileNotFoundError: If the model directory is found neither locally nor in the
                shared model root, or it contains no supported config file.

        Note:
            ``self.columns`` is updated to the register layout that was actually decoded.
        """
        # Get config file from `file`, then bodyparts from config file.
        # Underscores in the stem stand in for path separators of the model path.
        model_dir = Path(path.stem[self._pattern_offset :].replace("_", "/")).parent

        # Check if model directory exists in local or shared directories.
        # Local directory is prioritized over shared directory.
        local_config_file_dir = path.parent / model_dir
        shared_config_file_dir = Path(self._model_root) / model_dir if self._model_root else None
        if local_config_file_dir.exists():
            config_file_dir = local_config_file_dir
        elif shared_config_file_dir and shared_config_file_dir.exists():
            config_file_dir = shared_config_file_dir
        else:
            msg = f"Cannot find model dir in local directory: {local_config_file_dir}"
            if shared_config_file_dir:
                msg += f" or shared directory: {shared_config_file_dir}"
            raise FileNotFoundError(msg)

        config_file = self.get_config_file(config_file_dir)
        identities = self.get_class_names(config_file)
        parts = self.get_bodyparts(config_file)

        # Using bodyparts, assign column names to Harp register values, and read data in default format.
        BONSAI_SLEAP_V2 = 0.2
        BONSAI_SLEAP_V3 = 0.3
        try:  # Bonsai.Sleap0.2: a single identity likelihood column
            bonsai_sleap_v = BONSAI_SLEAP_V2
            columns = ["identity", "identity_likelihood"]
            for part in parts:
                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
            self.columns = columns
            data = super().read(path)
        except ValueError:  # column mismatch; Bonsai.Sleap0.3: one likelihood column per identity
            bonsai_sleap_v = BONSAI_SLEAP_V3
            columns = ["identity"]
            columns.extend([f"{identity}_likelihood" for identity in identities])
            for part in parts:
                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
            self.columns = columns
            data = super().read(path)

        # combine all identity_likelihood cols into a single column
        if bonsai_sleap_v == BONSAI_SLEAP_V3:
            identity_likelihood = data.apply(
                lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities},
                axis=1,
            )
            data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True)
            data.insert(1, "identity_likelihood", identity_likelihood)

        # Replace identity indices with identity labels
        data = self.class_int2str(data, identities)

        # Set new columns, and reformat `data` from wide (one row per sample) to long
        # (one row per sample and body part).
        n_parts = len(parts)
        new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
        new_data = np.empty((len(data) * n_parts, len(new_columns)), dtype="O")
        new_index = np.empty(len(data) * n_parts, dtype=data.index.values.dtype)
        # Materialize the wide values once instead of once per body part.
        values = data.values
        index_values = data.index.values
        for i, part in enumerate(parts):
            # Columns 0-1 are identity/identity_likelihood; each part then spans x, y, likelihood.
            min_col = 2 + i * 3
            max_col = 2 + (i + 1) * 3
            new_data[i::n_parts, 0:2] = values[:, 0:2]
            new_data[i::n_parts, 2] = part
            new_data[i::n_parts, 3:6] = values[:, min_col:max_col]
            new_index[i::n_parts] = index_values
        data = pd.DataFrame(new_data, new_index, columns=new_columns)

        # Set model column using model_dir. Assign one value per row: a Series would be
        # aligned on index labels (its RangeIndex vs. our timestamps), leaving NaN.
        if include_model:
            data["model"] = [model_dir] * len(data)
        return data

    @staticmethod
    def get_class_names(config_file_path: Path) -> list[str]:
        """Returns a list of classes from a model's config file.

        Args:
            config_file_path: Path to the model config file. Only SLEAP "confmap_config"
                files are supported.

        Returns:
            The "classes" of the first non-empty "class_vectors" entry found under
            ``config["model"]["heads"]``, or an empty list if there is none.

        Raises:
            ValueError: If the config file is not a supported type.
            KeyError: If the expected keys are missing from the config.
        """
        # Reject unsupported files before trying to parse them.
        if config_file_path.stem != "confmap_config":  # SLEAP
            raise ValueError(f"The model config file '{config_file_path}' is not supported.")
        with open(config_file_path) as f:
            config = json.load(f)
        try:
            heads = config["model"]["heads"]
            class_vectors = Pose._recursive_lookup(heads, "class_vectors")
            if class_vectors is not None:
                return class_vectors["classes"]
            return []
        except KeyError as err:
            raise KeyError(f"Cannot find class_vectors in {config_file_path}.") from err

    @staticmethod
    def get_bodyparts(config_file_path: Path) -> list[str]:
        """Returns a list of bodyparts from a model's config file.

        For a SLEAP "confmap_config" file, the first entry is ``anchor_<anchor_part>`` and the
        remaining entries are the "part_names". Other config files yield an empty list.

        Raises:
            KeyError: If the anchor part or part names cannot be found in a SLEAP config.
        """
        parts = []
        with open(config_file_path) as f:
            config = json.load(f)
        if config_file_path.stem == "confmap_config":  # SLEAP
            try:
                heads = config["model"]["heads"]
                parts = [f"anchor_{Pose._find_nested_key(heads, 'anchor_part')}"]
                parts += Pose._find_nested_key(heads, "part_names")
            except KeyError as err:
                raise KeyError(f"Cannot find anchor or bodyparts in {config_file_path}.") from err
        return parts

    @staticmethod
    def class_int2str(data: pd.DataFrame, classes: list[str]) -> pd.DataFrame:
        """Converts a class integer in a tracking data dataframe to its associated string (subject id).

        Args:
            data: DataFrame containing a column named "identity" with integer class identifiers.
            classes: List of class names corresponding to the integer identifiers in the "identity" column.

        Returns:
            DataFrame with the "identity" column converted to string class names. The input
            frame is modified in place and returned; it is left unchanged if ``classes`` is empty.
        """
        if classes:
            identity_mapping = dict(enumerate(classes))
            data["identity"] = data["identity"].replace(identity_mapping)
        return data

    @classmethod
    def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path:
        """Returns the config file from a model's config directory.

        Args:
            config_file_dir: Directory of the model.
            config_file_names: Candidate file names, tried in order. Defaults to the SLEAP
                "confmap_config.json".

        Raises:
            FileNotFoundError: If none of the candidate files exist in ``config_file_dir``.
        """
        if config_file_names is None:
            config_file_names = ["confmap_config.json"]  # SLEAP (add for other trackers to this list)
        config_file = None
        for f in config_file_names:
            if (config_file_dir / f).exists():
                config_file = config_file_dir / f
                break
        if config_file is None:
            raise FileNotFoundError(f"Cannot find config file in {config_file_dir}")
        return config_file

    @staticmethod
    def _find_nested_key(obj: dict, key: str) -> Any:
        """Returns the value of the first found nested key, raising KeyError if there is none."""
        value = Pose._recursive_lookup(obj, key)
        if value is None:
            raise KeyError(key)
        return value

    @staticmethod
    def _recursive_lookup(obj: Any, key: str) -> Any:
        """Returns the value of the first found nested key.

        Dicts and lists are searched depth-first. Falsy values (e.g. None, "", [], 0) are
        treated as "not found" and the search continues. Returns None if nothing is found.
        """
        if isinstance(obj, dict):
            if found := obj.get(key):  # found it!
                return found
            for item in obj.values():
                if found := Pose._recursive_lookup(item, key):
                    return found
        elif isinstance(obj, list):
            for item in obj:
                if found := Pose._recursive_lookup(item, key):
                    return found  # pragma: no cover