%load_ext autoreload
%autoreload 2
# %flow mode reactive

from datetime import datetime
import sys
import os
import warnings
from pathlib import Path
from typing import Any, Tuple, List, Dict

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import statsmodels.api as sm
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm

import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *
from aeon.dj_pipeline import acquisition, streams, subject, tracking
from swc.aeon.io import api as aeon_api
from aeon.schema.schemas import social02
data_dir = Path("/ceph/aeon/aeon/code/scratchpad/methods_paper_data") # Change this to your desired directory
os.makedirs(data_dir, exist_ok=True)
cm2px = 5.2  # 1 cm = 5.2 px roughly for top camera
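For reference, any threshold given in centimeters below is converted to pixels with this factor. A minimal sketch (values here are illustrative only):

# Illustrative cm <-> px conversions using the factor above
speed_thresh_px_s = 100 * cm2px  # 100 cm/s -> 520 px/s
move_thresh_px = 4 * cm2px       # 4 cm -> 20.8 px
dist_cm = 26.0 / cm2px           # 26 px -> 5 cm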
experiments = [
    {
        "name": "social0.2-aeon3",
        "presocial_start": "2024-01-31 11:00:00",
        "presocial_end": "2024-02-08 15:00:00",
        "social_start": "2024-02-09 16:00:00",
        "social_end": "2024-02-23 13:00:00",
        "postsocial_start": "2024-02-25 17:00:00",
        "postsocial_end": "2024-03-02 14:00:00",
    },
    {
        "name": "social0.2-aeon4",
        "presocial_start": "2024-01-31 11:00:00",
        "presocial_end": "2024-02-08 15:00:00",
        "social_start": "2024-02-09 17:00:00",
        "social_end": "2024-02-23 12:00:00",
        "postsocial_start": "2024-02-25 18:00:00",
        "postsocial_end": "2024-03-02 13:00:00",
    },
    {
        "name": "social0.3-aeon3",
        "presocial_start": "2024-06-08 19:00:00",
        "presocial_end": "2024-06-17 13:00:00",
        "social_start": "2024-06-25 11:00:00",
        "social_end": "2024-07-06 13:00:00",
        "postsocial_start": "2024-07-07 16:00:00",
        "postsocial_end": "2024-07-14 14:00:00",
    },
    {
        "name": "social0.3-aeon4",
        "presocial_start": "2024-06-08 19:00:00",
        "presocial_end": "2024-06-17 14:00:00",
        "social_start": "2024-06-19 12:00:00",
        "social_end": "2024-07-03 14:00:00",
        "postsocial_start": "2024-07-04 11:00:00",
        "postsocial_end": "2024-07-13 12:00:00",
    },
    {
        "name": "social0.4-aeon3",
        "presocial_start": "2024-08-16 17:00:00",
        "presocial_end": "2024-08-24 10:00:00",
        "social_start": "2024-08-28 11:00:00",
        "social_end": "2024-09-09 13:00:00",
        "postsocial_start": "2024-09-09 18:00:00",
        "postsocial_end": "2024-09-22 16:00:00",
    },
    {
        "name": "social0.4-aeon4",
        "presocial_start": "2024-08-16 15:00:00",
        "presocial_end": "2024-08-24 10:00:00",
        "social_start": "2024-08-28 10:00:00",
        "social_end": "2024-09-09 01:00:00",
        "postsocial_start": "2024-09-09 15:00:00",
        "postsocial_end": "2024-09-22 16:00:00",
    },
]
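A small convenience lookup over this table can be handy when querying a single period; this helper is hypothetical (not used elsewhere in this notebook):

def get_period_bounds(experiments: list, name: str, period: str) -> tuple:
    """Hypothetical helper: returns (start, end) strings for one experiment and period."""
    exp = next(e for e in experiments if e["name"] == name)
    return exp[f"{period}_start"], exp[f"{period}_end"]

# get_period_bounds(experiments, "social0.2-aeon3", "social")
# -> ('2024-02-09 16:00:00', '2024-02-23 13:00:00')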

Helper functions#

def save_data_to_parquet(
    df: pd.DataFrame,
    experiment_name: str,
    period_name: str,
    data_type: str,
    data_dir: Path
) -> Path:
    """Saves any DataFrame to a parquet file with consistent naming and metadata.

    Args:
        df (pd.DataFrame): Data to save
        experiment_name (str): Name of the experiment
        period_name (str): Period name (presocial, social, postsocial)
        data_type (str): Type of data (position, patch, foraging, rfid, sleep, explore)
        data_dir (Path): Directory to save the file

    Returns:
        Path: Path to the saved file
    """
    # Create directory if it doesn't exist
    os.makedirs(data_dir, exist_ok=True)

    # Add period column for reference if not already present
    df = df.copy()
    if 'period' not in df.columns:
        df['period'] = period_name

    # Preserve any meaningful index (named or MultiIndex) as columns so it
    # survives the ignore_index concat in load_data_from_parquet
    if df.index.name or isinstance(df.index, pd.MultiIndex):
        df = df.reset_index()

    # Create filename
    filename = f"{experiment_name}_{period_name}_{data_type}.parquet"
    file_path = data_dir / filename

    print(f"  Saving to {file_path}...")
    # Save to parquet with compression
    df.to_parquet(file_path, compression="snappy")

    # Report file stats
    file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
    memory_usage_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
    print(
        f"  Saved successfully: {len(df)} rows, {file_size_mb:.2f} MB on disk, "
        f"{memory_usage_mb:.2f} MB in memory"
    )

    return file_path


def load_data_from_parquet(
    experiment_name: str | None,
    period: str | None,
    data_type: str,
    data_dir: Path,
    set_time_index: bool = False
) -> pd.DataFrame:
    """Loads saved data from parquet files.

    Args:
        experiment_name (str, optional): Filter by experiment name. If None, load all experiments.
        period (str, optional): Filter by period (presocial, social, postsocial). If None, load all periods.
        data_type (str): Type of data to load (position, patch, foraging, rfid, sleep, explore)
        data_dir (Path): Directory containing parquet files.
        set_time_index (bool, optional): If True, set 'time' column as DataFrame index.

    Returns:
        pd.DataFrame: Combined DataFrame of all matching parquet files.
    """
    if not data_dir.exists():
        print(f"Directory {data_dir} does not exist. No data files found.")
        return pd.DataFrame()

    # Create pattern based on filters
    pattern = ""
    if experiment_name:
        pattern += f"{experiment_name}_"
    else:
        pattern += "*_"

    if period:
        pattern += f"{period}_"
    else:
        pattern += "*_"

    pattern += f"{data_type}.parquet"

    # Find matching files
    matching_files = list(data_dir.glob(pattern))

    if not matching_files:
        print(f"No matching data files found with pattern: {pattern}")
        return pd.DataFrame()

    print(f"Found {len(matching_files)} matching files")

    # Load and concatenate matching files
    dfs = []
    total_rows = 0
    for file in matching_files:
        print(f"Loading {file}...")
        df = pd.read_parquet(file)
        total_rows += len(df)
        dfs.append(df)
        print(f"  Loaded {len(df)} rows")

    # Combine data
    if dfs:
        combined_df = pd.concat(dfs, ignore_index=True)
        if set_time_index and 'time' in combined_df.columns:
            combined_df = combined_df.set_index('time')
        print(f"Combined data: {len(combined_df)} rows")
        return combined_df
    else:
        return pd.DataFrame()
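A quick round-trip sanity check of the two helpers above, using a toy DataFrame and a temporary directory (illustrative only; nothing below depends on it):

# Hypothetical round-trip: save a toy frame, then load it back from a temp dir
import tempfile

_tmp_dir = Path(tempfile.mkdtemp())
_toy_df = pd.DataFrame(
    {"time": pd.date_range("2024-01-01", periods=3, freq="1s"), "x": [0.1, 0.2, 0.3]}
)
save_data_to_parquet(_toy_df, "demo-exp", "presocial", "toy", _tmp_dir)
_back = load_data_from_parquet(
    experiment_name="demo-exp",
    period="presocial",
    data_type="toy",
    data_dir=_tmp_dir,
    set_time_index=True,
)
assert len(_back) == 3 and _back.index.name == "time"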


def save_all_experiment_data(
    experiments: list,
    periods: list,
    data_dict: dict,
    data_type: str,
    data_dir: Path
) -> None:
    """Save data for all experiments and periods in a standardized way.
    
    Args:
        experiments (list): List of experiment dictionaries with 'name' field
        data_dict (dict): Nested dictionary with structure {exp_name: {period_name: dataframe}}
        data_type (str): Type of data (position, patch, foraging, rfid, sleep, explore)
        data_dir (Path): Directory to save files
        periods (list): List of periods to process
    """
    # Save individual experiment data
    for exp in experiments:
        for period in periods:
            df = data_dict[exp['name']][period]
            if isinstance(df, pd.DataFrame) and not df.empty:
                save_data_to_parquet(
                    df,
                    exp['name'],
                    period,
                    data_type,
                    data_dir
                )


def excise_swaps(pos_df: pd.DataFrame, max_speed: float) -> pd.DataFrame:
    """Excises swaps in the position data.

    Args:
        pos_df (pd.DataFrame): DataFrame containing position data of a single subject.
        max_speed (float): Maximum speed (px/s) threshold over which we assume a swap.

    Returns:
        pd.DataFrame: DataFrame with swaps excised.
    """
    dt = pos_df.index.diff().total_seconds()
    dx = pos_df["x"].diff()
    dy = pos_df["y"].diff()
    pos_df["inst_speed"] = np.sqrt(dx**2 + dy**2) / dt

    # Identify jumps
    jumps = (pos_df["inst_speed"] > max_speed)
    shift_down = jumps.shift(1)
    shift_down.iloc[0] = False
    shift_up = jumps.shift(-1)
    shift_up.iloc[len(jumps) - 1] = False
    jump_starts = jumps & ~shift_down
    jump_ends = jumps & ~shift_up
    jump_start_indices = np.where(jump_starts)[0]
    jump_end_indices = np.where(jump_ends)[0]

    if np.any(jumps):

        # Ensure the lengths match
        if len(jump_start_indices) > len(jump_end_indices):  # jump-in-progress at start
            jump_end_indices = np.append(jump_end_indices, len(pos_df) - 1)
        elif len(jump_start_indices) < len(jump_end_indices):  # jump-in-progress at end
            jump_start_indices = np.insert(jump_start_indices, 0, 0)

        # Excise jumps by setting speed to nan in jump regions and dropping nans
        for start, end in zip(jump_start_indices, jump_end_indices, strict=True):
            pos_df.loc[pos_df.index[start]:pos_df.index[end], "inst_speed"] = np.nan
        pos_df.dropna(subset=["inst_speed"], inplace=True)

    return pos_df
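A minimal sketch of excise_swaps on synthetic data with one injected jump (assumes pandas >= 2.1 for DatetimeIndex.diff; values are made up):

# Hypothetical demo: a 6-sample trajectory with one implausible jump at sample 3
_idx = pd.date_range("2024-01-01", periods=6, freq="100ms")
_demo_df = pd.DataFrame({"x": [0, 1, 2, 500, 3, 4], "y": [0] * 6}, index=_idx)
# Rows within the jump (and the first row, whose speed is NaN) are dropped
_clean_df = excise_swaps(_demo_df, max_speed=100 * cm2px)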

Data loading and saving#

Patch data#

def load_subject_patch_data(
    key: dict[str, str], period_start: str, period_end: str
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Loads subject patch data for a specified time period.

    Args:
        key (dict): The key to filter the subject patch data.
        period_start (str): The start time for the period.
        period_end (str): The end time for the period.

    Returns:
        tuple: A tuple containing:
            - patch_info (pd.DataFrame): Information about patches.
            - block_subject_patch_data (pd.DataFrame): Data for the specified period.
            - block_subject_patch_pref (pd.DataFrame): Preference data for the specified period.
    """
    patch_info = (
        BlockAnalysis.Patch()
        & key
        & f"block_start >= '{period_start}'"
        & f"block_start <= '{period_end}'"
    ).fetch(
        "block_start",
        "patch_name",
        "patch_rate",
        "patch_offset",
        "wheel_timestamps",
        as_dict=True,
    )

    block_subject_patch_data = (
        BlockSubjectAnalysis.Patch()
        & key
        & f"block_start >= '{period_start}'"
        & f"block_start <= '{period_end}'"
    ).fetch(format="frame")

    block_subject_patch_pref = (
        BlockSubjectAnalysis.Preference()
        & key
        & f"block_start >= '{period_start}'"
        & f"block_start <= '{period_end}'"
    ).fetch(format="frame")

    # Always return a DataFrame (an empty list becomes an empty DataFrame)
    patch_info = pd.DataFrame(patch_info)

    if isinstance(block_subject_patch_data, pd.DataFrame) and not block_subject_patch_data.empty:
        block_subject_patch_data.reset_index(inplace=True)

    if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:
        block_subject_patch_pref.reset_index(inplace=True)

    return patch_info, block_subject_patch_data, block_subject_patch_pref
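A hedged usage sketch of this loader (dates taken from the experiments table above; requires access to the DataJoint pipeline):

# Hypothetical usage: fetch patch data for the social period of one experiment
_key = {"experiment_name": "social0.2-aeon3"}
_info, _data, _pref = load_subject_patch_data(
    _key, "2024-02-09 16:00:00", "2024-02-23 13:00:00"
)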
def ensure_ts_arr_datetime(array) -> np.ndarray:
    """Coerces an array of timestamps to dtype datetime64[ns] (handles empty arrays)."""
    if len(array) == 0:
        return np.array([], dtype="datetime64[ns]")
    return np.array(array, dtype="datetime64[ns]")
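A quick check of the coercion helper (illustrative):

ensure_ts_arr_datetime([])                       # -> array([], dtype='datetime64[ns]')
ensure_ts_arr_datetime(["2024-02-09T16:00:00"])  # -> array of datetime64[ns]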
patch_info_dict = {}
subject_patch_data_dict = {}
subject_patch_pref_dict = {}

for exp in experiments:
    key = {"experiment_name": exp["name"]}

    # Define periods
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }

    # Initialize nested dictionaries for this experiment
    patch_info_dict[exp["name"]] = {}
    subject_patch_data_dict[exp["name"]] = {}
    subject_patch_pref_dict[exp["name"]] = {}

    # Load data for each period
    for period_name, (period_start, period_end) in periods.items():
        period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
        period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")

        # Load data for this period
        patch_info, block_subject_patch_data, block_subject_patch_pref = (
            load_subject_patch_data(key, period_start, period_end)
        )

        # Drop rows with NaNs in the 'final_preference' columns
        block_subject_patch_pref = block_subject_patch_pref.dropna(
            subset=["final_preference_by_time", "final_preference_by_wheel"]
        )

        # Extra processing on patch_info
        if isinstance(patch_info, pd.DataFrame) and not patch_info.empty:
            # Add experiment_name and period columns for reference
            patch_info.insert(0, "experiment_name", exp["name"])
            patch_info.insert(1, "period", period_name)

        # Extra processing on block_subject_patch_data
        if isinstance(block_subject_patch_data, pd.DataFrame) and not block_subject_patch_data.empty:
            # Add period column for reference
            block_subject_patch_data.insert(1, "period", period_name)

            # For pre-social and post-social periods check n_subjects per block (should == 1)
            if period_name in ["presocial", "postsocial"]:
                n_subjects = block_subject_patch_data.groupby("block_start")[
                    "subject_name"
                ].nunique()
                if (n_subjects != 1).any():
                    warnings.warn(
                        f"Pre or post social data for {exp['name']} has blocks with more than one "
                        f"subject being tracked. Data needs to be fixed or cleaned."
                    )

            # Ensure timestamp array columns are the correct dtype (datetime64[ns])
            for ts_col in (
                "pellet_timestamps",
                "in_patch_rfid_timestamps",
                "in_patch_timestamps",
            ):
                if ts_col in block_subject_patch_data.columns:
                    block_subject_patch_data[ts_col] = block_subject_patch_data[
                        ts_col
                    ].apply(ensure_ts_arr_datetime)

        # Extra processing on block_subject_patch_pref
        if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:
            # Add period column for reference
            block_subject_patch_pref.insert(1, "period", period_name)

        # Store the data (patch_info as DataFrame now)
        patch_info_dict[exp["name"]][period_name] = patch_info
        subject_patch_data_dict[exp["name"]][period_name] = block_subject_patch_data
        subject_patch_pref_dict[exp["name"]][period_name] = block_subject_patch_pref
# Preview data from the last experiment/period processed above
display(patch_info)
display(block_subject_patch_data)
display(block_subject_patch_pref)
# Save the data to parquet files
save_all_experiment_data(
    experiments=experiments,
    periods=["presocial", "social", "postsocial"],
    data_dict=patch_info_dict,
    data_type="patchinfo",
    data_dir=data_dir,
)

save_all_experiment_data(
    experiments=experiments,
    periods=["presocial", "social", "postsocial"],
    data_dict=subject_patch_data_dict,
    data_type="patch",
    data_dir=data_dir,
)

save_all_experiment_data(
    experiments=experiments,
    periods=["presocial", "social", "postsocial"],
    data_dict=subject_patch_pref_dict,
    data_type="patchpref",
    data_dir=data_dir,
)
# Load the data from parquet files
combined_social_patch_info_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="social",
    data_type="patchinfo",
    data_dir=data_dir,
)
display(combined_social_patch_info_df)

combined_social_patch_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="social",
    data_type="patch",
    data_dir=data_dir,
)
display(combined_social_patch_df)

combined_social_patch_pref_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="social",
    data_type="patchpref",
    data_dir=data_dir,
)
display(combined_social_patch_pref_df)

Foraging bouts#

def load_foraging_bouts(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads foraging bout data for blocks falling within a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').

    Returns:
        pd.DataFrame: Concatenated dataframe of foraging bouts for all matching blocks.
                      Returns an empty dataframe with predefined columns if no data found.
    """
    # Fetch block start times within the specified period
    blocks = (
        Block & key & f"block_start >= '{period_start}'" & f"block_end <= '{period_end}'"
    ).fetch("block_start")

    # Retrieve foraging bouts for each block
    bouts = []
    for block_start in blocks:
        block_key = key | {"block_start": str(block_start)}
        bouts.append(get_foraging_bouts(block_key, min_pellets=1))

    # Return concatenated DataFrame or empty fallback
    if bouts:
        return pd.concat(bouts, ignore_index=True)
    else:
        return pd.DataFrame(
            columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]
        )
# Create a dictionary to hold foraging data for each experiment and period
foraging_data_dict = {}

for exp in experiments:
    key = {"experiment_name": exp["name"]}

    # Initialize nested dictionary for this experiment
    foraging_data_dict[exp["name"]] = {}

    # Define periods
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }

    # Load data for each period
    for period_name, (period_start, period_end) in periods.items():
        period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
        period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")

        # Load foraging data for this period
        foraging_df = load_foraging_bouts(key, period_start, period_end)

        # Add experiment name as a column if not already present
        if "experiment_name" not in foraging_df.columns:
            foraging_df.insert(0, "experiment_name", exp["name"])

        # Add period column for reference
        foraging_df["period"] = period_name

        # Store the data
        foraging_data_dict[exp["name"]][period_name] = foraging_df
# Save the foraging data to parquet files
save_all_experiment_data(
    experiments=experiments,
    periods=["presocial", "social", "postsocial"],
    data_dict=foraging_data_dict,
    data_type="foraging",
    data_dir=data_dir,
)
# Load the combined social period foraging data
combined_social_foraging_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="social",
    data_type="foraging",
    data_dir=data_dir,
)
display(combined_social_foraging_df)

RFID data#

def load_rfid_events(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads RFID events data for chunks falling within a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').

    Returns:
        pd.DataFrame: DataFrame containing RFID events for the specified period.
                      Returns an empty dataframe with predefined columns if no data found.
    """
    # Fetch RFID events within the specified period
    rfid_events_df = (
        streams.RfidReader * streams.RfidReaderRfidEvents
        & key
        & f'chunk_start >= "{period_start}"'
        & f'chunk_start <= "{period_end}"'
    ).fetch(format="frame")

    if not isinstance(rfid_events_df, pd.DataFrame) or rfid_events_df.empty:
        # Return empty DataFrame with expected columns if no data found
        return pd.DataFrame(
            columns=[
                "experiment_name",
                "chunk_start",
                "rfid_reader_name",
                "sample_count",
                "timestamps",
                "rfid",
            ]
        )

    # Get subject details for RFID mapping
    subject_detail = subject.SubjectDetail.fetch(format="frame")
    subject_detail.reset_index(inplace=True)

    # Create mapping from RFID (lab_id) to subject name
    rfid_to_subject = dict(zip(subject_detail["lab_id"], subject_detail["subject"]))

    rfid_events_df["rfid"] = [
        [rfid_to_subject.get(str(rfid)) for rfid in rfid_array]
        for rfid_array in rfid_events_df["rfid"]
    ]

    # Extract experiment_name and chunk_start from the MultiIndex before resetting
    rfid_events_df["experiment_name"] = rfid_events_df.index.get_level_values(
        "experiment_name"
    )
    rfid_events_df["chunk_start"] = rfid_events_df.index.get_level_values("chunk_start")

    # Reset the index and drop the index columns
    rfid_events_df = rfid_events_df.reset_index(drop=True)

    # Reorder columns to put experiment_name first and chunk_start second
    cols = ["experiment_name", "chunk_start"] + [
        col
        for col in rfid_events_df.columns
        if col not in ["experiment_name", "chunk_start"]
    ]
    rfid_events_df = rfid_events_df[cols]

    return rfid_events_df
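The per-chunk rfid arrays are mapped element-wise to subject names via the lab_id lookup; a toy sketch of that mapping (IDs below are made up):

# Hypothetical sketch of the RFID -> subject mapping applied to one chunk
_rfid_to_subject = {"977200010001": "BAA-0000001"}  # made-up lab_id -> subject
_chunk_rfids = np.array([977200010001, 977200010001])
[_rfid_to_subject.get(str(r)) for r in _chunk_rfids]
# -> ['BAA-0000001', 'BAA-0000001']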
# Create a dictionary to hold RFID data for each experiment and period
rfid_data_dict = {}

for exp in experiments:
    key = {"experiment_name": exp["name"]}

    # Initialize nested dictionary for this experiment
    rfid_data_dict[exp["name"]] = {}

    # Define periods
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }

    # Load data for each period
    for period_name, (period_start, period_end) in periods.items():
        period_start_str = period_start
        period_end_str = period_end

        # Handle datetime objects if needed
        if not isinstance(period_start, str):
            period_start_str = period_start.strftime("%Y-%m-%d %H:%M:%S")
        if not isinstance(period_end, str):
            period_end_str = period_end.strftime("%Y-%m-%d %H:%M:%S")

        # Load RFID data for this period
        rfid_df = load_rfid_events(key, period_start_str, period_end_str)

        # Add experiment name as a column if not already present
        if "experiment_name" not in rfid_df.columns:
            rfid_df.insert(0, "experiment_name", exp["name"])

        # Add period column for reference
        rfid_df["period"] = period_name

        # Store the data
        rfid_data_dict[exp["name"]][period_name] = rfid_df
# Save the RFID data to parquet files
save_all_experiment_data(
    experiments=experiments,
    periods=["presocial", "social", "postsocial"],
    data_dict=rfid_data_dict,
    data_type="rfid",
    data_dir=data_dir,
)
# Load the combined social period RFID data
combined_social_rfid_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="social",
    data_type="rfid",
    data_dir=data_dir,
)
display(combined_social_rfid_df)

Position data#

def load_position_data(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads position data (centroid tracking) for a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period.
        period_end (str): End datetime of the time period.

    Returns:
        pd.DataFrame: DataFrame containing position data for the specified period.
                     Returns an empty DataFrame if no data found.
    """
    try:
        print(f"  Querying data from {period_start} to {period_end}...")

        # Create chunk restriction for the time period
        chunk_restriction = acquisition.create_chunk_restriction(
            key["experiment_name"], period_start, period_end
        )

        # Fetch centroid tracking data for the specified period
        centroid_df = (
            streams.SpinnakerVideoSource * tracking.DenoisedTracking.Subject
            & key
            & {"spinnaker_video_source_name": "CameraTop"}
            & chunk_restriction
        ).fetch(format="frame")

        centroid_df = centroid_df.reset_index()
        centroid_df = centroid_df.rename(
            columns={
                "subject_name": "identity_name",
                "timestamps": "time",
                "subject_likelihood": "identity_likelihood",
            }
        )
        centroid_df = centroid_df.explode(
            ["time", "identity_likelihood", "x", "y", "likelihood"]
        )
        centroid_df = centroid_df[
            [
                "time",
                "experiment_name",
                "identity_name",
                "identity_likelihood",
                "x",
                "y",
                "likelihood",
            ]
        ].set_index("time")

        # Clean up the dataframe
        if isinstance(centroid_df, pd.DataFrame) and not centroid_df.empty:
            if "spinnaker_video_source_name" in centroid_df.columns:
                centroid_df.drop(columns=["spinnaker_video_source_name"], inplace=True)
            print(f"  Retrieved {len(centroid_df)} rows of position data")
        else:
            print("  No data found for the specified period")

        return centroid_df

    except Exception as e:
        print(
            f"  Error loading position data for {key['experiment_name']} ({period_start} "
            f"to {period_end}): {e}"
        )
        return pd.DataFrame()
# Create a dictionary to hold position data for each experiment and period
position_data_dict = {}

for exp in experiments:
    key = {"experiment_name": exp["name"]}

    # Initialize nested dictionary for this experiment
    position_data_dict[exp["name"]] = {}

    # Define periods
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }

    # Load data for each period
    for period_name, (period_start, period_end) in periods.items():
        print(f"  Loading {period_name} period...")

        period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
        period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")

        # Load position data for this period
        position_df = load_position_data(key, period_start, period_end)
        position_df.reset_index(inplace=True)

        # Add period column for reference if not empty
        if isinstance(position_df, pd.DataFrame) and not position_df.empty:
            position_df["period"] = period_name

        # Store the data
        position_data_dict[exp["name"]][period_name] = position_df

        # Print data size info
        if isinstance(position_df, pd.DataFrame) and not position_df.empty:
            memory_usage_mb = position_df.memory_usage(deep=True).sum() / (1024 * 1024)
            print(
                f"    {period_name}: {len(position_df)} rows, {memory_usage_mb:.2f} MB in memory"
            )
        else:
            print(f"    {period_name}: No data available")
# Save the position data to parquet files
save_all_experiment_data(
    experiments=experiments,
    periods=["presocial", "social", "postsocial"],
    data_dict=position_data_dict,
    data_type="position",
    data_dir=data_dir,
)
# Load the combined social period position data
combined_social_position_df = load_data_from_parquet(
    experiment_name="social0.4-aeon3",
    period="social",
    data_type="position",
    data_dir=data_dir,
    set_time_index=True,
)
display(combined_social_position_df.sort_index())

Weight data#

def load_weight_data(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads weight data for a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').

    Returns:
        pd.DataFrame: Weight data for the specified period.
                      Returns an empty dataframe if no data found.
    """
    try:
        weight_df = (
            acquisition.Environment.SubjectWeight
            & key
            & f"chunk_start >= '{period_start}'"
            & f"chunk_start <= '{period_end}'"
        ).fetch(format="frame")

        if isinstance(weight_df, pd.DataFrame) and not weight_df.empty:
            return weight_df
        return pd.DataFrame()
    except Exception as e:
        print(
            f"Error loading weight data for {key} from {period_start} to {period_end}: {e}"
        )
        return pd.DataFrame()
# Create a dictionary to hold weight data for each experiment and period
weight_data_dict = {}

for exp in experiments:
    key = {"experiment_name": exp["name"]}

    # Initialize nested dictionary for this experiment
    weight_data_dict[exp["name"]] = {}

    # Define periods
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }

    # Load data for each period
    for period_name, (period_start, period_end) in periods.items():
        # Convert to datetime if needed (assuming they're already strings in the right format)
        if isinstance(period_start, str):
            period_start_dt = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
            period_end_dt = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")
        else:
            period_start_dt = period_start
            period_end_dt = period_end

        # Load weight data for this period
        weight_df = load_weight_data(key, str(period_start_dt), str(period_end_dt))

        # Add experiment name as a column if not already present and data exists
        if isinstance(weight_df, pd.DataFrame) and not weight_df.empty:
            if "experiment_name" not in weight_df.columns:
                weight_df.insert(0, "experiment_name", exp["name"])

            # Add period column for reference
            weight_df["period"] = period_name

        # Store the data
        weight_data_dict[exp["name"]][period_name] = weight_df
# Save the weight data to parquet files
save_all_experiment_data(
    experiments=experiments,
    periods=["presocial", "social", "postsocial"],
    data_dict=weight_data_dict,
    data_type="weight",
    data_dir=data_dir,
)
# Load the combined social period weight data
combined_social_weight_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",  # or whichever experiment you want
    period="social",
    data_type="weight",
    data_dir=data_dir,
)
display(combined_social_weight_df)

Sleep bouts#

# Given pos_df and animal name, return all sleep bouts in a df within the pos_df time period


def sleep_bouts(
    pos_df: pd.DataFrame,
    subject: str,
    move_thresh: float = 4 * cm2px,  # cm -> px
    max_speed: float = 100 * cm2px,  # cm/s -> px/s
) -> pd.DataFrame:
    """Returns sleep bouts for a given animal within the specified position data time period.

    Args:
        pos_df (pd.DataFrame): DataFrame containing position data.
        subject (str): Name of the animal to filter by.
        move_thresh (float): Movement (in px) threshold to define sleep bouts.
        max_speed (float): Maximum speed threshold for excising swaps.

    Returns:
        pd.DataFrame: DataFrame containing sleep bouts for the specified animal.
    """
    animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if not isinstance(animal_data, pd.DataFrame) or animal_data.empty:
        print(f"No position data found for {subject}")
        return pd.DataFrame()

    # Set some constants and placeholder `windows_df` which will be combined into `bouts_df`
    sleep_win = pd.Timedelta("1m")
    sleep_windows_df = pd.DataFrame(
        columns=["subject", "start", "end", "duration", "period"]
    )

    # Create time windows based on start and end time
    data_start_time = animal_data.index.min()
    data_end_time = animal_data.index.max()
    window_starts = pd.date_range(
        start=data_start_time, end=data_end_time, freq=sleep_win
    )

    # <s> Process each time window
    period = animal_data["period"].iloc[0]
    pbar = tqdm(window_starts, desc=f"Processing sleep bouts for {subject} in {period}")
    for win_start in pbar:
        win_end = win_start + sleep_win
        win_data = animal_data[
            (animal_data.index >= win_start) & (animal_data.index < win_end)
        ].copy()
        if len(win_data) < 100:  # skip windows with too little data
            continue

        # Excise id swaps (based on pos / speed jumps)
        # win_data = excise_swaps(win_data, max_speed)

        # Calculate the displacement - maximum distance between any two points in the window
        dx = win_data["x"].max() - win_data["x"].min()
        dy = win_data["y"].max() - win_data["y"].min()
        displacement = np.sqrt(dx**2 + dy**2)

        # If displacement is less than threshold, consider it a sleep bout
        if displacement < move_thresh:
            new_bout = {
                "subject": subject,
                "start": win_start,
                "end": win_end,
                "duration": sleep_win,
                "period": win_data["period"].iloc[0],
            }
            sleep_windows_df = pd.concat(
                [sleep_windows_df, pd.DataFrame([new_bout])], ignore_index=True
            )
    # </s>

    # <s> Now merge consecutive sleep windows into continuous bouts
    if not isinstance(sleep_windows_df, pd.DataFrame) or sleep_windows_df.empty:
        return pd.DataFrame(columns=["subject", "start", "end", "duration", "period"])
    # Initialize the merged bouts dataframe with the first window
    sleep_bouts_df = pd.DataFrame(
        [
            {
                "subject": subject,
                "start": sleep_windows_df.iloc[0]["start"],
                "end": sleep_windows_df.iloc[0]["end"],
                "duration": sleep_windows_df.iloc[0]["duration"],
                "period": sleep_windows_df.iloc[0]["period"],
            }
        ]
    )
    # Iterate through remaining windows and merge consecutive ones
    for i in range(1, len(sleep_windows_df)):
        current_window = sleep_windows_df.iloc[i]
        last_bout = sleep_bouts_df.iloc[-1]

        if current_window["start"] == last_bout["end"]:  # continue bout
            sleep_bouts_df.at[len(sleep_bouts_df) - 1, "end"] = current_window["end"]
            sleep_bouts_df.at[len(sleep_bouts_df) - 1, "duration"] = (
                sleep_bouts_df.iloc[-1]["end"] - sleep_bouts_df.iloc[-1]["start"]
            )
        else:  # start a new bout
            new_bout = {
                "subject": subject,
                "start": current_window["start"],
                "end": current_window["end"],
                "duration": current_window["duration"],
                "period": current_window["period"],
            }
            sleep_bouts_df = pd.concat(
                [sleep_bouts_df, pd.DataFrame([new_bout])], ignore_index=True
            )
    # </s>

    # Set min bout time
    min_bout_time = pd.Timedelta("2m")
    sleep_bouts_df = sleep_bouts_df[sleep_bouts_df["duration"] >= min_bout_time]

    return sleep_bouts_df
"""Save sleep bouts to parquet files for all experiments and periods."""

# For each experiment, for each period, load pos data, get sleep bouts, save to parquet

pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
    sleep_bouts_data_dict = {}
    key = {"experiment_name": exp["name"]}
    sleep_bouts_data_dict[exp["name"]] = {}
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }
    pbar_period = tqdm(periods.items(), desc="Processing periods", leave=False)
    for period_name, (period_start, period_end) in pbar_period:
        print(f"  Loading {period_name} period...")
        period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
        period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")

        # load pos data for this period
        pos_df = load_data_from_parquet(
            experiment_name=exp["name"],
            period=period_name,
            data_type="position",
            data_dir=data_dir,
            set_time_index=True,
        )

        # get sleep bouts for each subject
        if pos_df.empty:
            print(f"  No position data for {exp['name']} {period_name}; skipping.")
            continue
        subjects = pos_df["identity_name"].unique()
        sleep_bouts_df = pd.DataFrame(
            columns=["subject", "start", "end", "duration", "period"]
        )
        for subject_name in subjects:  # avoid shadowing the `subject` module
            subject_bouts = sleep_bouts(pos_df, subject_name)
            if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
                sleep_bouts_df = pd.concat(
                    [sleep_bouts_df, subject_bouts], ignore_index=True
                )

        # save data dict
        sleep_bouts_data_dict[exp["name"]][period_name] = sleep_bouts_df
        save_all_experiment_data(
            experiments=[exp],
            periods=[period_name],
            data_dict=sleep_bouts_data_dict,
            data_type="sleep",
            data_dir=data_dir,
        )
        print(f"  Saved sleep bouts for {exp['name']} during {period_name} period.")
"""Example usage:"""

sleep_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="presocial",
    data_type="sleep",
    data_dir=data_dir,
    set_time_index=True,
)
display(sleep_df)

Drink bouts#

def drink_bouts(
    pos_df: pd.DataFrame,
    subject: str,
    spout_loc: tuple[float, float],  # x,y spout location in px
    start_radius: float = 4 * cm2px,  # must be within X cm of spout, in px
    move_thresh: float = 2.5 * cm2px,  # during bout must move less than X cm, in px
    min_dur: float = 6,  # min duration of bout in seconds
    max_dur: float = 90,  # max duration of bout in seconds
) -> pd.DataFrame:  # cols: subject, start, end, duration, period
    """Returns drink bouts for a given animal within the specified position data time period."""

    animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if not isinstance(animal_data, pd.DataFrame) or animal_data.empty:
        print(f"No position data found for {subject}")
        return pd.DataFrame(columns=["subject", "start", "end", "duration", "period"])

    # Smooth position data to 100ms intervals - only numeric columns
    numeric_cols = animal_data.select_dtypes(include=[np.number]).columns
    animal_data = animal_data[numeric_cols].resample("100ms").mean().interpolate()
    animal_data = animal_data.dropna()

    # Add non-numeric columns back
    animal_data["identity_name"] = subject
    animal_data["experiment_name"] = pos_df["experiment_name"].iloc[0]
    animal_data["period"] = pos_df["period"].iloc[0]

    # Calculate distance from spout
    spout_x, spout_y = spout_loc
    animal_data["dist_to_spout"] = np.sqrt(
        (animal_data["x"] - spout_x) ** 2 + (animal_data["y"] - spout_y) ** 2
    )

    # Find potential bout starts (within start_radius of spout)
    near_spout = animal_data["dist_to_spout"] <= start_radius

    # Get period info
    period = animal_data["period"].iloc[0]

    drink_bouts_df = pd.DataFrame(
        columns=["subject", "start", "end", "duration", "period"]
    )

    pbar = tqdm(
        total=len(animal_data), desc=f"Processing drink bouts for {subject} in {period}"
    )
    i = 0
    while i < len(animal_data):
        # Skip if not near spout
        if not near_spout.iloc[i]:
            i += 1
            pbar.update(1)
            continue

        # Found potential bout start
        bout_start_time = animal_data.index[i]
        bout_start_idx = i

        # Track movement during potential bout
        start_x = animal_data["x"].iloc[i]
        start_y = animal_data["y"].iloc[i]

        j = i
        max_displacement = 0

        # Continue while near spout and not moving too much
        while j < len(animal_data):
            current_time = animal_data.index[j]
            elapsed_time = (current_time - bout_start_time).total_seconds()

            # Calculate displacement from bout start position
            current_x = animal_data["x"].iloc[j]
            current_y = animal_data["y"].iloc[j]
            displacement = np.sqrt(
                (current_x - start_x) ** 2 + (current_y - start_y) ** 2
            )
            max_displacement = max(max_displacement, displacement)

            # Check if bout should end
            if max_displacement > move_thresh:
                break

            if elapsed_time > max_dur:
                break

            j += 1

        # Determine bout end
        bout_end_time = (
            animal_data.index[j - 1] if j > bout_start_idx else bout_start_time
        )
        bout_duration = (bout_end_time - bout_start_time).total_seconds()

        # Check if bout meets duration criteria
        if min_dur < bout_duration < max_dur:
            new_bout = {
                "subject": subject,
                "start": bout_start_time,
                "end": bout_end_time,
                "duration": pd.Timedelta(seconds=bout_duration),
                "period": period,
            }
            drink_bouts_df = pd.concat(
                [drink_bouts_df, pd.DataFrame([new_bout])], ignore_index=True
            )

        # Move to next potential bout (skip past current bout end)
        next_i = max(j, i + 1)
        pbar.update(next_i - i)
        i = next_i

    pbar.close()
    return drink_bouts_df
"""Save drink bouts to parquet files for all experiments and periods."""

# For each experiment, for each period, load pos data, get drink bouts, save to parquet

pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
    drink_bouts_data_dict = {}
    key = {"experiment_name": exp["name"]}
    drink_bouts_data_dict[exp["name"]] = {}
    periods = ["presocial", "social", "postsocial"]
    pbar_period = tqdm(periods, desc="Processing periods", leave=False)
    for period_name in pbar_period:
        print(f"  Loading {period_name} period...")

        # load pos data for this period
        pos_df = load_data_from_parquet(
            experiment_name=exp["name"],
            period=period_name,
            data_type="position",
            data_dir=data_dir,
            set_time_index=True,
        )

        # get drink bouts for each subject
        if pos_df.empty:
            print(f"  No position data for {exp['name']} {period_name}; skipping.")
            continue
        subjects = pos_df["identity_name"].unique()
        drink_bouts_df = pd.DataFrame(
            columns=["subject", "start", "end", "duration", "period"]
        )
        for subject_name in subjects:  # avoid shadowing the `subject` module
            spout_loc = (1280, 500) if "aeon3" in exp["name"] else (1245, 535)
            subject_bouts = drink_bouts(pos_df, subject_name, spout_loc)
            if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
                drink_bouts_df = pd.concat(
                    [drink_bouts_df, subject_bouts], ignore_index=True
                )

        # save data dict
        drink_bouts_data_dict[exp["name"]][period_name] = drink_bouts_df
        save_all_experiment_data(
            experiments=[exp],
            periods=[period_name],
            data_dict=drink_bouts_data_dict,
            data_type="drink",
            data_dir=data_dir,
        )
        print(f"  Saved drink bouts for {exp['name']} during {period_name} period.")
display(drink_bouts_df)

Explore bouts#

# Given pos_df, animal name, and nest center, return all explore bouts in a df

nest_center = np.array((1215, 530))  # fallback default; recomputed per experiment below
nest_radius = 14 * cm2px  # 14 cm, in px


def explore_bouts(
    pos_df: pd.DataFrame,
    subject: str,
    nest_center: np.ndarray,
    nest_radius: float = 14 * cm2px,  # 14 cm, in px
    max_speed: float = 100 * cm2px,  # 100 cm/s, in px/s
) -> pd.DataFrame:
    """Returns exploration bouts for a given animal within the specified position data time period.

    Args:
        pos_df (pd.DataFrame): DataFrame containing position data.
        subject (str): Name of the animal to filter by.
        nest_center (np.ndarray): Coordinates of the nest center.
        nest_radius (float): Radius of the nest area (default: 14 cm in px).
        max_speed (float): Maximum speed threshold for excising swaps (default: 100 cm/s in px/s).

    Returns:
        pd.DataFrame: DataFrame containing exploration bouts for the specified animal.
    """
    animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if not isinstance(animal_data, pd.DataFrame) or animal_data.empty:
        print(f"No position data found for {subject}")
        return pd.DataFrame()

    # Set some constants and placeholder `windows_df` which will be combined into `bouts_df`
    explore_win = pd.Timedelta("1m")
    explore_windows_df = pd.DataFrame(
        columns=["subject", "start", "end", "duration", "period"]
    )

    # Create time windows based on start and end time
    data_start_time = animal_data.index.min()
    data_end_time = animal_data.index.max()
    window_starts = pd.date_range(
        start=data_start_time, end=data_end_time, freq=explore_win
    )

    # <s> Process each time window (use tqdm for progress bar)
    period = animal_data["period"].iloc[0]
    pbar = tqdm(window_starts, desc=f"Processing explore bouts for {subject} in {period}")
    for win_start in pbar:
        win_end = win_start + explore_win
        win_data = animal_data[
            (animal_data.index >= win_start) & (animal_data.index < win_end)
        ].copy()
        if len(win_data) < 100:  # skip windows with too little data
            continue

        # Excise id swaps (based on pos / speed jumps)
        win_data = excise_swaps(win_data, max_speed)

        # If majority of time in a window is outside nest, consider it an explore bout
        dx = win_data["x"] - nest_center[0]
        dy = win_data["y"] - nest_center[1]
        distance_from_nest = np.sqrt(dx**2 + dy**2)
        frac_out_nest = (distance_from_nest > nest_radius).sum() / len(win_data)
        if frac_out_nest > 0.5:
            new_bout = {
                "subject": subject,
                "start": win_start,
                "end": win_end,
                "duration": explore_win,
                "period": win_data["period"].iloc[0],
            }
            explore_windows_df = pd.concat(
                [explore_windows_df, pd.DataFrame([new_bout])], ignore_index=True
            )
    # </s>

    # <s> Now merge consecutive explore windows into continuous bouts
    if not isinstance(explore_windows_df, pd.DataFrame) or explore_windows_df.empty:
        return pd.DataFrame(columns=["subject", "start", "end", "duration", "period"])
    # Initialize the merged bouts dataframe with the first window
    explore_bouts_df = pd.DataFrame(
        [
            {
                "subject": subject,
                "start": explore_windows_df.iloc[0]["start"],
                "end": explore_windows_df.iloc[0]["end"],
                "duration": explore_windows_df.iloc[0]["duration"],
                "period": explore_windows_df.iloc[0]["period"],
            }
        ]
    )
    # Iterate through remaining windows and merge consecutive ones
    for i in range(1, len(explore_windows_df)):
        current_window = explore_windows_df.iloc[i]
        last_bout = explore_bouts_df.iloc[-1]

        if current_window["start"] == last_bout["end"]:  # continue bout
            explore_bouts_df.at[len(explore_bouts_df) - 1, "end"] = current_window["end"]
            explore_bouts_df.at[len(explore_bouts_df) - 1, "duration"] = (
                explore_bouts_df.iloc[-1]["end"] - explore_bouts_df.iloc[-1]["start"]
            )
        else:  # start a new bout
            new_bout = {
                "subject": subject,
                "start": current_window["start"],
                "end": current_window["end"],
                "duration": current_window["duration"],
                "period": current_window["period"],
            }
            explore_bouts_df = pd.concat(
                [explore_bouts_df, pd.DataFrame([new_bout])], ignore_index=True
            )
    # </s>

    return explore_bouts_df
"""Save explore bouts to parquet files for each experiment and period"""

# For each experiment, for each period, load pos data, get explore bouts, save to parquet

pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
    explore_bouts_data_dict = {}
    key = {"experiment_name": exp["name"]}

    # get nest center for this exp
    epoch_query = acquisition.Epoch & (acquisition.Chunk & key).proj("epoch_start")
    active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query
    roi_locs = dict(
        zip(*active_region_query.fetch("region_name", "region_data"), strict=True)
    )
    points = roi_locs["NestRegion"]["ArrayOfPoint"]
    vertices = np.array([[float(point["X"]), float(point["Y"])] for point in points])
    nest_center = np.mean(vertices, axis=0)

    explore_bouts_data_dict[exp["name"]] = {}
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }
    pbar_period = tqdm(periods.items(), desc="Processing periods", leave=False)
    for period_name, (period_start, period_end) in pbar_period:
        print(f"  Loading {period_name} period...")
        period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
        period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")

        # load pos data for this period
        pos_df = load_data_from_parquet(
            experiment_name=exp["name"],
            period=period_name,
            data_type="position",
            data_dir=data_dir,
            set_time_index=True,
        )

        # get explore bouts for each subject
        if pos_df.empty:
            print(f"  No position data for {exp['name']} {period_name}; skipping.")
            continue
        subjects = pos_df["identity_name"].unique()
        explore_bouts_df = pd.DataFrame(
            columns=["subject", "start", "end", "duration", "period"]
        )
        for subject_name in subjects:  # avoid shadowing the `subject` module
            subject_bouts = explore_bouts(pos_df, subject_name, nest_center)
            if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
                explore_bouts_df = pd.concat(
                    [explore_bouts_df, subject_bouts], ignore_index=True
                )

        # save data dict
        explore_bouts_data_dict[exp["name"]][period_name] = explore_bouts_df
        save_all_experiment_data(
            experiments=[exp],
            periods=[period_name],
            data_dict=explore_bouts_data_dict,
            data_type="explore",
            data_dir=data_dir,
        )
        print(f"  Saved explore bouts for {exp['name']} during {period_name} period.")


key = {"experiment_name": "social0.2-aeon3"}
epoch_query = acquisition.Epoch & (acquisition.Chunk & key).proj("epoch_start")
active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query
roi_locs = dict(
    zip(*active_region_query.fetch("region_name", "region_data"), strict=True)
)
"""Example usage:"""

explore_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="presocial",
    data_type="explore",
    data_dir=data_dir,
    set_time_index=True,
)
display(explore_df)