%load_ext autoreload
%autoreload 2
# %flow mode reactive
from datetime import datetime
import sys
import os
import warnings
from pathlib import Path
from typing import Any, Tuple, List, Dict
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import statsmodels.api as sm
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *
from aeon.dj_pipeline import acquisition, streams, subject, tracking
from swc.aeon.io import api as aeon_api
from aeon.schema.schemas import social02
data_dir = Path("/ceph/aeon/aeon/code/scratchpad/methods_paper_data") # Change this to your desired directory
os.makedirs(data_dir, exist_ok=True)
cm2px = 5.2  # ~5.2 px per cm for the top camera
experiments = [
{
"name": "social0.2-aeon3",
"presocial_start": "2024-01-31 11:00:00",
"presocial_end": "2024-02-08 15:00:00",
"social_start": "2024-02-09 16:00:00",
"social_end": "2024-02-23 13:00:00",
"postsocial_start": "2024-02-25 17:00:00",
"postsocial_end": "2024-03-02 14:00:00",
},
{
"name": "social0.2-aeon4",
"presocial_start": "2024-01-31 11:00:00",
"presocial_end": "2024-02-08 15:00:00",
"social_start": "2024-02-09 17:00:00",
"social_end": "2024-02-23 12:00:00",
"postsocial_start": "2024-02-25 18:00:00",
"postsocial_end": "2024-03-02 13:00:00",
},
{
"name": "social0.3-aeon3",
"presocial_start": "2024-06-08 19:00:00",
"presocial_end": "2024-06-17 13:00:00",
"social_start": "2024-06-25 11:00:00",
"social_end": "2024-07-06 13:00:00",
"postsocial_start": "2024-07-07 16:00:00",
"postsocial_end": "2024-07-14 14:00:00",
},
{
"name": "social0.3-aeon4",
"presocial_start": "2024-06-08 19:00:00",
"presocial_end": "2024-06-17 14:00:00",
"social_start": "2024-06-19 12:00:00",
"social_end": "2024-07-03 14:00:00",
"postsocial_start": "2024-07-04 11:00:00",
"postsocial_end": "2024-07-13 12:00:00",
},
{
"name": "social0.4-aeon3",
"presocial_start": "2024-08-16 17:00:00",
"presocial_end": "2024-08-24 10:00:00",
"social_start": "2024-08-28 11:00:00",
"social_end": "2024-09-09 13:00:00",
"postsocial_start": "2024-09-09 18:00:00",
"postsocial_end": "2024-09-22 16:00:00",
},
{
"name": "social0.4-aeon4",
"presocial_start": "2024-08-16 15:00:00",
"presocial_end": "2024-08-24 10:00:00",
"social_start": "2024-08-28 10:00:00",
"social_end": "2024-09-09 01:00:00",
"postsocial_start": "2024-09-09 15:00:00",
"postsocial_end": "2024-09-22 16:00:00",
},
]
Helper functions#
def save_data_to_parquet(
df: pd.DataFrame,
experiment_name: str,
period_name: str,
data_type: str,
data_dir: Path
) -> Path:
"""Saves any DataFrame to a parquet file with consistent naming and metadata.
Args:
df (pd.DataFrame): Data to save
experiment_name (str): Name of the experiment
period_name (str): Period name (presocial, social, postsocial)
data_type (str): Type of data (position, patch, foraging, rfid, sleep, explore)
data_dir (Path): Directory to save the file
Returns:
Path: Path to the saved file
"""
# Create directory if it doesn't exist
os.makedirs(data_dir, exist_ok=True)
# Add period column for reference if not already present
df = df.copy()
if 'period' not in df.columns:
df['period'] = period_name
    # Flatten any named index into a column so that loading (which concatenates
    # with ignore_index=True) never silently drops it
    if df.index.name:
        df = df.reset_index()
# Create filename
filename = f"{experiment_name}_{period_name}_{data_type}.parquet"
file_path = data_dir / filename
print(f" Saving to {file_path}...")
# Save to parquet with compression
df.to_parquet(file_path, compression="snappy")
# Report file stats
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
    memory_usage_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
    print(
        f"  Saved successfully: {len(df)} rows, {file_size_mb:.2f} MB on disk "
        f"({memory_usage_mb:.2f} MB in memory)"
    )
return file_path
def load_data_from_parquet(
experiment_name: str | None,
period: str | None,
data_type: str,
data_dir: Path,
set_time_index: bool = False
) -> pd.DataFrame:
"""Loads saved data from parquet files.
Args:
experiment_name (str, optional): Filter by experiment name. If None, load all experiments.
period (str, optional): Filter by period (presocial, social, postsocial). If None, load all periods.
data_type (str): Type of data to load (position, patch, foraging, rfid, sleep, explore)
data_dir (Path): Directory containing parquet files.
set_time_index (bool, optional): If True, set 'time' column as DataFrame index.
Returns:
pd.DataFrame: Combined DataFrame of all matching parquet files.
"""
if not data_dir.exists():
print(f"Directory {data_dir} does not exist. No data files found.")
return pd.DataFrame()
# Create pattern based on filters
pattern = ""
if experiment_name:
pattern += f"{experiment_name}_"
else:
pattern += "*_"
if period:
pattern += f"{period}_"
else:
pattern += "*_"
pattern += f"{data_type}.parquet"
# Find matching files
matching_files = list(data_dir.glob(pattern))
if not matching_files:
print(f"No matching data files found with pattern: {pattern}")
return pd.DataFrame()
print(f"Found {len(matching_files)} matching files")
# Load and concatenate matching files
dfs = []
total_rows = 0
for file in matching_files:
print(f"Loading {file}...")
df = pd.read_parquet(file)
total_rows += len(df)
dfs.append(df)
print(f" Loaded {len(df)} rows")
# Combine data
if dfs:
combined_df = pd.concat(dfs, ignore_index=True)
if set_time_index and 'time' in combined_df.columns:
combined_df = combined_df.set_index('time')
print(f"Combined data: {len(combined_df)} rows")
return combined_df
else:
return pd.DataFrame()
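"""A minimal save/load round-trip sketch for the two helpers above (commented
out; the "demo" data type and toy frame are hypothetical, real usage follows
in the sections below)."""
# demo_df = pd.DataFrame(
#     {"time": pd.date_range("2024-01-01", periods=3, freq="1s"), "x": [0.0, 1.0, 2.0]}
# )
# save_data_to_parquet(demo_df, "social0.2-aeon3", "presocial", "demo", data_dir)
# roundtrip_df = load_data_from_parquet(
#     experiment_name="social0.2-aeon3",
#     period="presocial",
#     data_type="demo",
#     data_dir=data_dir,
#     set_time_index=True,
# )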
def save_all_experiment_data(
experiments: list,
periods: list,
data_dict: dict,
data_type: str,
data_dir: Path
) -> None:
"""Save data for all experiments and periods in a standardized way.
Args:
experiments (list): List of experiment dictionaries with 'name' field
data_dict (dict): Nested dictionary with structure {exp_name: {period_name: dataframe}}
data_type (str): Type of data (position, patch, foraging, rfid, sleep, explore)
data_dir (Path): Directory to save files
periods (list): List of periods to process
"""
# Save individual experiment data
for exp in experiments:
for period in periods:
df = data_dict[exp['name']][period]
if isinstance(df, pd.DataFrame) and not df.empty:
save_data_to_parquet(
df,
exp['name'],
period,
data_type,
data_dir
)
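"""The driver loops below rebuild the same period mapping for every experiment.
A small helper like this captures the convention in one place (a sketch, not
used below; the loops keep their explicit dicts)."""
def get_periods(exp: dict) -> dict[str, tuple[str, str]]:
    """Returns {period_name: (start, end)} for one experiment dict."""
    return {
        name: (exp[f"{name}_start"], exp[f"{name}_end"])
        for name in ("presocial", "social", "postsocial")
    }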
def excise_swaps(pos_df: pd.DataFrame, max_speed: float) -> pd.DataFrame:
"""Excises swaps in the position data.
Args:
pos_df (pd.DataFrame): DataFrame containing position data of a single subject.
max_speed (float): Maximum speed (px/s) threshold over which we assume a swap.
Returns:
pd.DataFrame: DataFrame with swaps excised.
"""
dt = pos_df.index.diff().total_seconds()
dx = pos_df["x"].diff()
dy = pos_df["y"].diff()
pos_df["inst_speed"] = np.sqrt(dx**2 + dy**2) / dt
# Identify jumps
jumps = (pos_df["inst_speed"] > max_speed)
    shift_down = jumps.shift(1, fill_value=False)
    shift_up = jumps.shift(-1, fill_value=False)
jump_starts = jumps & ~shift_down
jump_ends = jumps & ~shift_up
jump_start_indices = np.where(jump_starts)[0]
jump_end_indices = np.where(jump_ends)[0]
if np.any(jumps):
# Ensure the lengths match
if len(jump_start_indices) > len(jump_end_indices): # jump-in-progress at start
jump_end_indices = np.append(jump_end_indices, len(pos_df) - 1)
elif len(jump_start_indices) < len(jump_end_indices): # jump-in-progress at end
jump_start_indices = np.insert(jump_start_indices, 0, 0)
# Excise jumps by setting speed to nan in jump regions and dropping nans
for start, end in zip(jump_start_indices, jump_end_indices, strict=True):
pos_df.loc[pos_df.index[start]:pos_df.index[end], "inst_speed"] = np.nan
pos_df.dropna(subset=["inst_speed"], inplace=True)
return pos_df
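"""A quick synthetic check of `excise_swaps` (commented out; coordinates are
made up): with 1 s sampling, the 500 px jump at index 2 exceeds max_speed and
is excised. The first row is also dropped, since its speed is undefined."""
# toy_df = pd.DataFrame(
#     {"x": [0.0, 1.0, 500.0, 501.0, 502.0], "y": [0.0] * 5},
#     index=pd.date_range("2024-01-01", periods=5, freq="1s"),
# )
# excise_swaps(toy_df, max_speed=100)  # keeps the rows at 1 s, 3 s, and 4 s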
Data loading and saving#
Patch data#
def load_subject_patch_data(
key: dict[str, str], period_start: str, period_end: str
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""Loads subject patch data for a specified time period.
Args:
key (dict): The key to filter the subject patch data.
period_start (str): The start time for the period.
period_end (str): The end time for the period.
Returns:
tuple: A tuple containing:
- patch_info (pd.DataFrame): Information about patches.
- block_subject_patch_data (pd.DataFrame): Data for the specified period.
- block_subject_patch_pref (pd.DataFrame): Preference data for the specified period.
"""
patch_info = (
BlockAnalysis.Patch()
& key
& f"block_start >= '{period_start}'"
& f"block_start <= '{period_end}'"
).fetch(
"block_start",
"patch_name",
"patch_rate",
"patch_offset",
"wheel_timestamps",
as_dict=True,
)
block_subject_patch_data = (
BlockSubjectAnalysis.Patch()
& key
& f"block_start >= '{period_start}'"
& f"block_start <= '{period_end}'"
).fetch(format="frame")
block_subject_patch_pref = (
BlockSubjectAnalysis.Preference()
& key
& f"block_start >= '{period_start}'"
& f"block_start <= '{period_end}'"
).fetch(format="frame")
if patch_info:
patch_info = pd.DataFrame(patch_info)
if isinstance(block_subject_patch_data, pd.DataFrame) and not block_subject_patch_data.empty:
block_subject_patch_data.reset_index(inplace=True)
if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:
block_subject_patch_pref.reset_index(inplace=True)
return patch_info, block_subject_patch_data, block_subject_patch_pref
def ensure_ts_arr_datetime(array):
    """Coerces an array of timestamps to dtype datetime64[ns] (empty-safe)."""
    if len(array) == 0:
        return np.array([], dtype="datetime64[ns]")
    return np.array(array, dtype="datetime64[ns]")
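"""Quick check of `ensure_ts_arr_datetime` (commented out): empty and string
inputs both come back as datetime64[ns] arrays, which parquet serializes cleanly."""
# ensure_ts_arr_datetime([])  # -> array([], dtype='datetime64[ns]')
# ensure_ts_arr_datetime(["2024-02-09 16:00:00"])  # -> one datetime64[ns] element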
patch_info_dict = {}
subject_patch_data_dict = {}
subject_patch_pref_dict = {}
for exp in experiments:
key = {"experiment_name": exp["name"]}
# Define periods
periods = {
"presocial": (exp["presocial_start"], exp["presocial_end"]),
"social": (exp["social_start"], exp["social_end"]),
"postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}
# Initialize nested dictionaries for this experiment
patch_info_dict[exp["name"]] = {}
subject_patch_data_dict[exp["name"]] = {}
subject_patch_pref_dict[exp["name"]] = {}
# Load data for each period
for period_name, (period_start, period_end) in periods.items():
period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")
# Load data for this period
patch_info, block_subject_patch_data, block_subject_patch_pref = (
load_subject_patch_data(key, period_start, period_end)
)
        # Drop NaNs in the 'final_preference' columns (guard against empty
        # frames, which lack these columns)
        if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:
            block_subject_patch_pref = block_subject_patch_pref.dropna(
                subset=["final_preference_by_time", "final_preference_by_wheel"]
            )
# Extra processing on patch_info
if isinstance(patch_info, pd.DataFrame) and not patch_info.empty:
# Add experiment_name and period columns for reference
patch_info.insert(0, "experiment_name", exp["name"])
patch_info.insert(1, "period", period_name)
# Extra processing on block_subject_patch_data
if isinstance(block_subject_patch_data, pd.DataFrame) and not block_subject_patch_data.empty:
# Add period column for reference
block_subject_patch_data.insert(1, "period", period_name)
# # Remove dummy patches
# block_subject_patch_data = block_subject_patch_data[
# ~block_subject_patch_data["patch_name"].str.contains("PatchDummy")
# ]
# For pre-social and post-social periods check n_subjects per block (should == 1)
if period_name in ["presocial", "postsocial"]:
n_subjects = block_subject_patch_data.groupby("block_start")[
"subject_name"
].nunique()
if (n_subjects != 1).any():
warnings.warn(
f"Pre or post social data for {exp['name']} has blocks with more than one "
f"subject being tracked. Data needs to be fixed or cleaned."
)
if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:
# Add period column for reference
block_subject_patch_pref.insert(1, "period", period_name)
# # Remove dummy patches
# block_subject_patch_pref = block_subject_patch_pref[
# ~block_subject_patch_pref["patch_name"].str.contains("PatchDummy")
# ]
# Ensure timestamps are correct type (datetime64[ns])
if "pellet_timestamps" in block_subject_patch_data.columns:
block_subject_patch_data["pellet_timestamps"] = block_subject_patch_data[
"pellet_timestamps"
].apply(ensure_ts_arr_datetime)
if "in_patch_rfid_timestamps" in block_subject_patch_data.columns:
block_subject_patch_data["in_patch_rfid_timestamps"] = (
block_subject_patch_data[
"in_patch_rfid_timestamps"
].apply(ensure_ts_arr_datetime)
)
if "in_patch_timestamps" in block_subject_patch_data.columns:
block_subject_patch_data["in_patch_timestamps"] = (
block_subject_patch_data[
"in_patch_timestamps"
].apply(ensure_ts_arr_datetime)
)
# Store the data (patch_info as DataFrame now)
patch_info_dict[exp["name"]][period_name] = patch_info
subject_patch_data_dict[exp["name"]][period_name] = block_subject_patch_data
subject_patch_pref_dict[exp["name"]][period_name] = block_subject_patch_pref
display(patch_info)
display(block_subject_patch_data)
display(block_subject_patch_pref)
# Save the data to parquet files
save_all_experiment_data(
experiments=experiments,
periods=["presocial", "social", "postsocial"],
data_dict=patch_info_dict,
data_type="patchinfo",
data_dir=data_dir,
)
save_all_experiment_data(
experiments=experiments,
periods=["presocial", "social", "postsocial"],
data_dict=subject_patch_data_dict,
data_type="patch",
data_dir=data_dir,
)
save_all_experiment_data(
experiments=experiments,
periods=["presocial", "social", "postsocial"],
data_dict=subject_patch_pref_dict,
data_type="patchpref",
data_dir=data_dir,
)
# Load the data from parquet files
combined_social_patch_info_df = load_data_from_parquet(
experiment_name="social0.2-aeon3",
period="social",
data_type="patchinfo",
data_dir=data_dir,
)
display(combined_social_patch_info_df)
combined_social_patch_df = load_data_from_parquet(
experiment_name="social0.2-aeon3",
period="social",
data_type="patch",
data_dir=data_dir,
)
display(combined_social_patch_df)
combined_social_patch_pref_df = load_data_from_parquet(
experiment_name="social0.2-aeon3",
period="social",
data_type="patchpref",
data_dir=data_dir,
)
display(combined_social_patch_pref_df)
Foraging bouts#
def load_foraging_bouts(
key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
"""Loads foraging bout data for blocks falling within a specified time period.
Args:
key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
Returns:
pd.DataFrame: Concatenated dataframe of foraging bouts for all matching blocks.
Returns an empty dataframe with predefined columns if no data found.
"""
# Fetch block start times within the specified period
blocks = (
Block & key & f"block_start >= '{period_start}'" & f"block_end <= '{period_end}'"
).fetch("block_start")
# Retrieve foraging bouts for each block
bouts = []
for block_start in blocks:
block_key = key | {"block_start": str(block_start)}
bouts.append(get_foraging_bouts(block_key, min_pellets=1))
# Return concatenated DataFrame or empty fallback
if bouts:
return pd.concat(bouts, ignore_index=True)
else:
return pd.DataFrame(
columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]
)
# Create a dictionary to hold foraging data for each experiment and period
foraging_data_dict = {}
for exp in experiments:
key = {"experiment_name": exp["name"]}
# Initialize nested dictionary for this experiment
foraging_data_dict[exp["name"]] = {}
# Define periods
periods = {
"presocial": (exp["presocial_start"], exp["presocial_end"]),
"social": (exp["social_start"], exp["social_end"]),
"postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}
# Load data for each period
for period_name, (period_start, period_end) in periods.items():
period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")
# Load foraging data for this period
foraging_df = load_foraging_bouts(key, period_start, period_end)
# Add experiment name as a column if not already present
if "experiment_name" not in foraging_df.columns:
foraging_df.insert(0, "experiment_name", exp["name"])
# Add period column for reference
foraging_df["period"] = period_name
# Store the data
foraging_data_dict[exp["name"]][period_name] = foraging_df
# Save the foraging data to parquet files
save_all_experiment_data(
experiments=experiments,
periods=["presocial", "social", "postsocial"],
data_dict=foraging_data_dict,
data_type="foraging",
data_dir=data_dir,
)
# Load the combined social period foraging data
combined_social_foraging_df = load_data_from_parquet(
experiment_name="social0.2-aeon3",
period="social",
data_type="foraging",
data_dir=data_dir,
)
display(combined_social_foraging_df)
RFID data#
def load_rfid_events(
key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
"""Loads RFID events data for chunks falling within a specified time period.
Args:
key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
Returns:
pd.DataFrame: DataFrame containing RFID events for the specified period.
Returns an empty dataframe with predefined columns if no data found.
"""
# Fetch RFID events within the specified period
rfid_events_df = (
streams.RfidReader * streams.RfidReaderRfidEvents
& key
& f'chunk_start >= "{period_start}"'
& f'chunk_start <= "{period_end}"'
).fetch(format="frame")
    if not isinstance(rfid_events_df, pd.DataFrame) or rfid_events_df.empty:
# Return empty DataFrame with expected columns if no data found
return pd.DataFrame(
columns=[
"experiment_name",
"chunk_start",
"rfid_reader_name",
"sample_count",
"timestamps",
"rfid",
]
)
# Get subject details for RFID mapping
subject_detail = subject.SubjectDetail.fetch(format="frame")
subject_detail.reset_index(inplace=True)
    # Map RFID tag (lab_id) to subject name
    rfid_to_subject = dict(zip(subject_detail["lab_id"], subject_detail["subject"]))
    rfid_events_df["rfid"] = [
        [rfid_to_subject.get(str(rfid)) for rfid in rfid_array]
        for rfid_array in rfid_events_df["rfid"]
    ]
# Extract experiment_name and chunk_start from the index before resetting
rfid_events_df["experiment_name"] = [idx[0] for idx in rfid_events_df.index]
rfid_events_df["chunk_start"] = [
idx[3] for idx in rfid_events_df.index
] # Assuming chunk_start is at index 3
# Reset the index and drop the index column
rfid_events_df = rfid_events_df.reset_index(drop=True)
# Reorder columns to put experiment_name first and chunk_start second
cols = ["experiment_name", "chunk_start"] + [
col
for col in rfid_events_df.columns
if col not in ["experiment_name", "chunk_start"]
]
rfid_events_df = rfid_events_df[cols]
return rfid_events_df
# Create a dictionary to hold RFID data for each experiment and period
rfid_data_dict = {}
for exp in experiments:
key = {"experiment_name": exp["name"]}
# Initialize nested dictionary for this experiment
rfid_data_dict[exp["name"]] = {}
# Define periods
periods = {
"presocial": (exp["presocial_start"], exp["presocial_end"]),
"social": (exp["social_start"], exp["social_end"]),
"postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}
# Load data for each period
for period_name, (period_start, period_end) in periods.items():
period_start_str = period_start
period_end_str = period_end
# Handle datetime objects if needed
if not isinstance(period_start, str):
period_start_str = period_start.strftime("%Y-%m-%d %H:%M:%S")
if not isinstance(period_end, str):
period_end_str = period_end.strftime("%Y-%m-%d %H:%M:%S")
# Load RFID data for this period
rfid_df = load_rfid_events(key, period_start_str, period_end_str)
# Add experiment name as a column if not already present
if "experiment_name" not in rfid_df.columns:
rfid_df.insert(0, "experiment_name", exp["name"])
# Add period column for reference
rfid_df["period"] = period_name
# Store the data
rfid_data_dict[exp["name"]][period_name] = rfid_df
# Save the RFID data to parquet files
save_all_experiment_data(
experiments=experiments,
periods=["presocial", "social", "postsocial"],
data_dict=rfid_data_dict,
data_type="rfid",
data_dir=data_dir,
)
# Load the combined social period RFID data
combined_social_rfid_df = load_data_from_parquet(
experiment_name="social0.2-aeon3",
period="social",
data_type="rfid",
data_dir=data_dir,
)
display(combined_social_rfid_df)
Position data#
def load_position_data(
key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
"""Loads position data (centroid tracking) for a specified time period.
Args:
key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
period_start (str): Start datetime of the time period.
period_end (str): End datetime of the time period.
Returns:
pd.DataFrame: DataFrame containing position data for the specified period.
Returns an empty DataFrame if no data found.
"""
try:
print(f" Querying data from {period_start} to {period_end}...")
# Create chunk restriction for the time period
chunk_restriction = acquisition.create_chunk_restriction(
key["experiment_name"], period_start, period_end
)
# Fetch centroid tracking data for the specified period
centroid_df = (
streams.SpinnakerVideoSource * tracking.DenoisedTracking.Subject
& key
& {"spinnaker_video_source_name": "CameraTop"}
& chunk_restriction
).fetch(format="frame")
centroid_df = centroid_df.reset_index()
centroid_df = centroid_df.rename(
columns={
"subject_name": "identity_name",
"timestamps": "time",
"subject_likelihood": "identity_likelihood",
}
)
centroid_df = centroid_df.explode(
["time", "identity_likelihood", "x", "y", "likelihood"]
)
centroid_df = centroid_df[
[
"time",
"experiment_name",
"identity_name",
"identity_likelihood",
"x",
"y",
"likelihood",
]
].set_index("time")
# Clean up the dataframe
if isinstance(centroid_df, pd.DataFrame) and not centroid_df.empty:
if "spinnaker_video_source_name" in centroid_df.columns:
centroid_df.drop(columns=["spinnaker_video_source_name"], inplace=True)
print(f" Retrieved {len(centroid_df)} rows of position data")
else:
print(" No data found for the specified period")
return centroid_df
except Exception as e:
print(
f" Error loading position data for {key['experiment_name']} ({period_start} "
f"to {period_end}): {e}"
)
return pd.DataFrame()
# Create a dictionary to hold position data for each experiment and period
position_data_dict = {}
for exp in experiments:
key = {"experiment_name": exp["name"]}
# Initialize nested dictionary for this experiment
position_data_dict[exp["name"]] = {}
# Define periods
periods = {
"presocial": (exp["presocial_start"], exp["presocial_end"]),
"social": (exp["social_start"], exp["social_end"]),
"postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}
# Load data for each period
for period_name, (period_start, period_end) in periods.items():
print(f" Loading {period_name} period...")
period_start = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
period_end = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")
# Load position data for this period
position_df = load_position_data(key, period_start, period_end)
        # Flatten the time index and add a period column if data was returned
        if isinstance(position_df, pd.DataFrame) and not position_df.empty:
            position_df.reset_index(inplace=True)
            position_df["period"] = period_name
# Store the data
position_data_dict[exp["name"]][period_name] = position_df
# Print data size info
if isinstance(position_df, pd.DataFrame) and not position_df.empty:
memory_usage_mb = position_df.memory_usage(deep=True).sum() / (1024 * 1024)
print(
f" {period_name}: {len(position_df)} rows, {memory_usage_mb:.2f} MB in memory"
)
else:
print(f" {period_name}: No data available")
# Save the position data to parquet files
save_all_experiment_data(
experiments=experiments,
periods=["presocial", "social", "postsocial"],
data_dict=position_data_dict,
data_type="position",
data_dir=data_dir,
)
# Load the combined social period position data
combined_social_position_df = load_data_from_parquet(
experiment_name="social0.4-aeon3",
period="social",
data_type="position",
data_dir=data_dir,
set_time_index=True,
)
display(combined_social_position_df.sort_index())
Weight data#
def load_weight_data(
key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
"""Loads weight data for a specified time period.
Args:
key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
Returns:
pd.DataFrame: Weight data for the specified period.
Returns an empty dataframe if no data found.
"""
try:
weight_df = (
acquisition.Environment.SubjectWeight
& key
& f"chunk_start >= '{period_start}'"
& f"chunk_start <= '{period_end}'"
).fetch(format="frame")
        return weight_df if isinstance(weight_df, pd.DataFrame) and not weight_df.empty else pd.DataFrame()
except Exception as e:
print(
f"Error loading weight data for {key} from {period_start} to {period_end}: {e}"
)
return pd.DataFrame()
# Create a dictionary to hold weight data for each experiment and period
weight_data_dict = {}
for exp in experiments:
key = {"experiment_name": exp["name"]}
# Initialize nested dictionary for this experiment
weight_data_dict[exp["name"]] = {}
# Define periods
periods = {
"presocial": (exp["presocial_start"], exp["presocial_end"]),
"social": (exp["social_start"], exp["social_end"]),
"postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}
# Load data for each period
for period_name, (period_start, period_end) in periods.items():
        # Normalize period bounds to datetime (they are '%Y-%m-%d %H:%M:%S' strings)
if isinstance(period_start, str):
period_start_dt = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
period_end_dt = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")
else:
period_start_dt = period_start
period_end_dt = period_end
# Load weight data for this period
weight_df = load_weight_data(key, str(period_start_dt), str(period_end_dt))
# Add experiment name as a column if not already present and data exists
if isinstance(weight_df, pd.DataFrame) and not weight_df.empty:
if "experiment_name" not in weight_df.columns:
weight_df.insert(0, "experiment_name", exp["name"])
# Add period column for reference
weight_df["period"] = period_name
# Store the data
weight_data_dict[exp["name"]][period_name] = weight_df
# Save the weight data to parquet files
save_all_experiment_data(
experiments=experiments,
periods=["presocial", "social", "postsocial"],
data_dict=weight_data_dict,
data_type="weight",
data_dir=data_dir,
)
# Load the combined social period weight data
combined_social_weight_df = load_data_from_parquet(
experiment_name="social0.2-aeon3", # or whichever experiment you want
period="social",
data_type="weight",
data_dir=data_dir,
)
display(combined_social_weight_df)
Sleep bouts#
# Given pos_df and an animal name, return all sleep bouts within the pos_df time period
def sleep_bouts(
pos_df: pd.DataFrame,
subject: str,
move_thresh: float = 4 * cm2px, # cm -> px
max_speed: float = 100 * cm2px, # cm/s -> px/s
) -> pd.DataFrame:
"""Returns sleep bouts for a given animal within the specified position data time period.
Args:
pos_df (pd.DataFrame): DataFrame containing position data.
subject (str): Name of the animal to filter by.
move_thresh (float): Movement (in px) threshold to define sleep bouts.
max_speed (float): Maximum speed threshold for excising swaps.
Returns:
pd.DataFrame: DataFrame containing sleep bouts for the specified animal.
"""
animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if not isinstance(animal_data, pd.DataFrame) or animal_data.empty:
print(f"No position data found for {subject}")
return pd.DataFrame()
# Set some constants and placeholder `windows_df` which will be combined into `bouts_df`
sleep_win = pd.Timedelta("1m")
sleep_windows_df = pd.DataFrame(
columns=["subject", "start", "end", "duration", "period"]
)
# Create time windows based on start and end time
data_start_time = animal_data.index.min()
data_end_time = animal_data.index.max()
window_starts = pd.date_range(
start=data_start_time, end=data_end_time, freq=sleep_win
)
# <s> Process each time window
period = animal_data["period"].iloc[0]
pbar = tqdm(window_starts, desc=f"Processing sleep bouts for {subject} in {period}")
for win_start in pbar:
win_end = win_start + sleep_win
win_data = animal_data[
(animal_data.index >= win_start) & (animal_data.index < win_end)
].copy()
if len(win_data) < 100: # skip windows with too little data
continue
# Excise id swaps (based on pos / speed jumps)
# win_data = excise_swaps(win_data, max_speed)
# Calculate the displacement - maximum distance between any two points in the window
dx = win_data["x"].max() - win_data["x"].min()
dy = win_data["y"].max() - win_data["y"].min()
displacement = np.sqrt(dx**2 + dy**2)
# If displacement is less than threshold, consider it a sleep bout
if displacement < move_thresh:
new_bout = {
"subject": subject,
"start": win_start,
"end": win_end,
"duration": sleep_win,
"period": win_data["period"].iloc[0],
}
sleep_windows_df = pd.concat(
[sleep_windows_df, pd.DataFrame([new_bout])], ignore_index=True
)
# </s>
# <s> Now merge consecutive sleep windows into continuous bouts
    if not isinstance(sleep_windows_df, pd.DataFrame) or sleep_windows_df.empty:
return pd.DataFrame(columns=["subject", "start", "end", "duration", "period"])
# Initialize the merged bouts dataframe with the first window
sleep_bouts_df = pd.DataFrame(
[
{
"subject": subject,
"start": sleep_windows_df.iloc[0]["start"],
"end": sleep_windows_df.iloc[0]["end"],
"duration": sleep_windows_df.iloc[0]["duration"],
"period": sleep_windows_df.iloc[0]["period"],
}
]
)
# Iterate through remaining windows and merge consecutive ones
for i in range(1, len(sleep_windows_df)):
current_window = sleep_windows_df.iloc[i]
last_bout = sleep_bouts_df.iloc[-1]
if current_window["start"] == last_bout["end"]: # continue bout
sleep_bouts_df.at[len(sleep_bouts_df) - 1, "end"] = current_window["end"]
sleep_bouts_df.at[len(sleep_bouts_df) - 1, "duration"] = (
sleep_bouts_df.iloc[-1]["end"] - sleep_bouts_df.iloc[-1]["start"]
)
else: # start a new bout
new_bout = {
"subject": subject,
"start": current_window["start"],
"end": current_window["end"],
"duration": current_window["duration"],
"period": current_window["period"],
}
sleep_bouts_df = pd.concat(
[sleep_bouts_df, pd.DataFrame([new_bout])], ignore_index=True
)
# </s>
# Set min bout time
min_bout_time = pd.Timedelta("2m")
sleep_bouts_df = sleep_bouts_df[sleep_bouts_df["duration"] >= min_bout_time]
return sleep_bouts_df
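"""A vectorized alternative to the window-merging loops used in `sleep_bouts`
above and `explore_bouts` below (a sketch, not wired in): consecutive windows
share an edge, so a new bout starts wherever `start` != the previous `end`."""
def merge_consecutive_windows(windows_df: pd.DataFrame) -> pd.DataFrame:
    """Merges edge-sharing windows into bouts; expects rows sorted by start."""
    bout_id = (windows_df["start"] != windows_df["end"].shift()).cumsum()
    merged = windows_df.groupby(bout_id).agg(
        subject=("subject", "first"),
        start=("start", "first"),
        end=("end", "last"),
        period=("period", "first"),
    )
    merged["duration"] = merged["end"] - merged["start"]
    return merged.reset_index(drop=True)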
"""Save sleep bouts to parquet files for all experiments and periods."""
# For each experiment, for each period, load pos data, get sleep bouts, save to parquet
pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
sleep_bouts_data_dict = {}
key = {"experiment_name": exp["name"]}
sleep_bouts_data_dict[exp["name"]] = {}
periods = {
"presocial": (exp["presocial_start"], exp["presocial_end"]),
"social": (exp["social_start"], exp["social_end"]),
"postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}
pbar_period = tqdm(periods.items(), desc="Processing periods", leave=False)
for period_name, (period_start, period_end) in pbar_period:
print(f" Loading {period_name} period...")
# load pos data for this period
pos_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period_name,
data_type="position",
data_dir=data_dir,
set_time_index=True,
)
        # get sleep bouts for each subject (skip if no position data was found)
        if pos_df.empty:
            continue
        subjects = pos_df["identity_name"].unique()
sleep_bouts_df = pd.DataFrame(
columns=["subject", "start", "end", "duration", "period"]
)
        for subject_name in subjects:  # avoid shadowing the `subject` module
            subject_bouts = sleep_bouts(pos_df, subject_name)
if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
sleep_bouts_df = pd.concat(
[sleep_bouts_df, subject_bouts], ignore_index=True
)
# save data dict
sleep_bouts_data_dict[exp["name"]][period_name] = sleep_bouts_df
save_all_experiment_data(
experiments=[exp],
periods=[period_name],
data_dict=sleep_bouts_data_dict,
data_type="sleep",
data_dir=data_dir,
)
print(f" Saved sleep bouts for {exp['name']} during {period_name} period.")
"""Example usage:"""
sleep_df = load_data_from_parquet(
experiment_name="social0.2-aeon3",
period="presocial",
data_type="sleep",
data_dir=data_dir,
set_time_index=True,
)
display(sleep_df)
Drink bouts#
def drink_bouts(
pos_df: pd.DataFrame,
subject: str,
    spout_loc: tuple[float, float],  # x,y spout location in px
    start_radius: float = 4 * cm2px,  # bout start must be within 4 cm of spout, in px
    move_thresh: float = 2.5 * cm2px,  # during a bout, must move less than 2.5 cm, in px
min_dur: float = 6, # min duration of bout in seconds
max_dur: float = 90, # max duration of bout in seconds
) -> pd.DataFrame: # cols: subject, start, end, duration, period
"""Returns drink bouts for a given animal within the specified position data time period."""
animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if not isinstance(animal_data, pd.DataFrame) or animal_data.empty:
print(f"No position data found for {subject}")
return pd.DataFrame(columns=["subject", "start", "end", "duration", "period"])
# Smooth position data to 100ms intervals - only numeric columns
numeric_cols = animal_data.select_dtypes(include=[np.number]).columns
animal_data = animal_data[numeric_cols].resample("100ms").mean().interpolate()
animal_data = animal_data.dropna()
# Add non-numeric columns back
animal_data["identity_name"] = subject
animal_data["experiment_name"] = pos_df["experiment_name"].iloc[0]
animal_data["period"] = pos_df["period"].iloc[0]
# Calculate distance from spout
spout_x, spout_y = spout_loc
animal_data["dist_to_spout"] = np.sqrt(
(animal_data["x"] - spout_x) ** 2 + (animal_data["y"] - spout_y) ** 2
)
# Find potential bout starts (within start_radius of spout)
near_spout = animal_data["dist_to_spout"] <= start_radius
# Get period info
period = animal_data["period"].iloc[0]
drink_bouts_df = pd.DataFrame(
columns=["subject", "start", "end", "duration", "period"]
)
pbar = tqdm(
total=len(animal_data), desc=f"Processing drink bouts for {subject} in {period}"
)
i = 0
while i < len(animal_data):
        pbar.n = i  # i can jump past a whole bout, so set the bar position directly
        pbar.refresh()
# Skip if not near spout
if not near_spout.iloc[i]:
i += 1
continue
# Found potential bout start
bout_start_time = animal_data.index[i]
bout_start_idx = i
# Track movement during potential bout
start_x = animal_data["x"].iloc[i]
start_y = animal_data["y"].iloc[i]
j = i
max_displacement = 0
# Continue while near spout and not moving too much
while j < len(animal_data):
current_time = animal_data.index[j]
elapsed_time = (current_time - bout_start_time).total_seconds()
# Calculate displacement from bout start position
current_x = animal_data["x"].iloc[j]
current_y = animal_data["y"].iloc[j]
displacement = np.sqrt(
(current_x - start_x) ** 2 + (current_y - start_y) ** 2
)
max_displacement = max(max_displacement, displacement)
# Check if bout should end
if max_displacement > move_thresh:
break
if elapsed_time > max_dur:
break
j += 1
# Determine bout end
bout_end_time = (
animal_data.index[j - 1] if j > bout_start_idx else bout_start_time
)
bout_duration = (bout_end_time - bout_start_time).total_seconds()
# Check if bout meets duration criteria
if min_dur < bout_duration < max_dur:
new_bout = {
"subject": subject,
"start": bout_start_time,
"end": bout_end_time,
"duration": pd.Timedelta(seconds=bout_duration),
"period": period,
}
drink_bouts_df = pd.concat(
[drink_bouts_df, pd.DataFrame([new_bout])], ignore_index=True
)
# Move to next potential bout (skip past current bout end)
i = max(j, i + 1)
pbar.close()
return drink_bouts_df
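"""A hedged optimization sketch for `drink_bouts` (not wired in): entries into
the spout radius can be found with a vectorized mask, so the scan above could
start only at False -> True transitions of `near_spout` rather than at every sample."""
def entry_indices(near_spout: pd.Series) -> np.ndarray:
    """Returns positional indices where the animal enters the spout radius."""
    entries = near_spout & ~near_spout.shift(fill_value=False)
    return np.flatnonzero(entries.to_numpy())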
"""Save drink bouts to parquet files for all experiments and periods."""
# For each experiment, for each period, load pos data, get drink bouts, save to parquet
pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
drink_bouts_data_dict = {}
key = {"experiment_name": exp["name"]}
drink_bouts_data_dict[exp["name"]] = {}
    pbar_period = tqdm(
        ["presocial", "social", "postsocial"], desc="Processing periods", leave=False
    )
for period_name in pbar_period:
print(f" Loading {period_name} period...")
# load pos data for this period
pos_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period_name,
data_type="position",
data_dir=data_dir,
set_time_index=True,
)
        # get drink bouts for each subject (skip if no position data was found)
        if pos_df.empty:
            continue
        subjects = pos_df["identity_name"].unique()
drink_bouts_df = pd.DataFrame(
columns=["subject", "start", "end", "duration", "period"]
)
        for subject_name in subjects:  # avoid shadowing the `subject` module
            spout_loc = (1280, 500) if "aeon3" in exp["name"] else (1245, 535)
            subject_bouts = drink_bouts(pos_df, subject_name, spout_loc)
if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
drink_bouts_df = pd.concat(
[drink_bouts_df, subject_bouts], ignore_index=True
)
# save data dict
drink_bouts_data_dict[exp["name"]][period_name] = drink_bouts_df
save_all_experiment_data(
experiments=[exp],
periods=[period_name],
data_dict=drink_bouts_data_dict,
data_type="drink",
data_dir=data_dir,
)
print(f" Saved drink bouts for {exp['name']} during {period_name} period.")
drink_bouts_df
Explore bouts#
# Given pos_df, an animal name, and the nest xy, return all explore bouts
# Default nest geometry (nest_center is overridden per experiment from the epoch config below)
nest_center = np.array((1215, 530))
nest_radius = 14 * cm2px  # 14 cm, in px
def explore_bouts(
pos_df: pd.DataFrame,
subject: str,
nest_center: np.ndarray,
    nest_radius: float = 14 * cm2px,  # 14 cm, in px
    max_speed: float = 100 * cm2px,  # 100 cm/s, in px/s
) -> pd.DataFrame:
"""Returns exploration bouts for a given animal within the specified position data time period.
Args:
pos_df (pd.DataFrame): DataFrame containing position data.
subject (str): Name of the animal to filter by.
nest_center (np.ndarray): Coordinates of the nest center.
nest_radius (float): Radius of the nest area (default: 14 cm in px).
max_speed (float): Maximum speed threshold for excising swaps (default: 100 cm/s in px/s).
Returns:
pd.DataFrame: DataFrame containing exploration bouts for the specified animal.
"""
animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if not isinstance(animal_data, pd.DataFrame) or animal_data.empty:
print(f"No position data found for {subject}")
return pd.DataFrame()
# Set some constants and placeholder `windows_df` which will be combined into `bouts_df`
explore_win = pd.Timedelta("1m")
explore_windows_df = pd.DataFrame(
columns=["subject", "start", "end", "duration", "period"]
)
# Create time windows based on start and end time
data_start_time = animal_data.index.min()
data_end_time = animal_data.index.max()
window_starts = pd.date_range(
start=data_start_time, end=data_end_time, freq=explore_win
)
# <s> Process each time window (use tqdm for progress bar)
period = animal_data["period"].iloc[0]
pbar = tqdm(window_starts, desc=f"Processing explore bouts for {subject} in {period}")
for win_start in pbar:
win_end = win_start + explore_win
win_data = animal_data[
(animal_data.index >= win_start) & (animal_data.index < win_end)
].copy()
if len(win_data) < 100: # skip windows with too little data
continue
# Excise id swaps (based on pos / speed jumps)
win_data = excise_swaps(win_data, max_speed)
# If majority of time in a window is outside nest, consider it an explore bout
dx = win_data["x"] - nest_center[0]
dy = win_data["y"] - nest_center[1]
distance_from_nest = np.sqrt(dx**2 + dy**2)
frac_out_nest = (distance_from_nest > nest_radius).sum() / len(win_data)
if frac_out_nest > 0.5:
new_bout = {
"subject": subject,
"start": win_start,
"end": win_end,
"duration": explore_win,
"period": win_data["period"].iloc[0],
}
explore_windows_df = pd.concat(
[explore_windows_df, pd.DataFrame([new_bout])], ignore_index=True
)
# </s>
# <s> Now merge consecutive explore windows into continuous bouts
    if not isinstance(explore_windows_df, pd.DataFrame) or explore_windows_df.empty:
return pd.DataFrame(columns=["subject", "start", "end", "duration", "period"])
# Initialize the merged bouts dataframe with the first window
explore_bouts_df = pd.DataFrame(
[
{
"subject": subject,
"start": explore_windows_df.iloc[0]["start"],
"end": explore_windows_df.iloc[0]["end"],
"duration": explore_windows_df.iloc[0]["duration"],
"period": explore_windows_df.iloc[0]["period"],
}
]
)
# Iterate through remaining windows and merge consecutive ones
for i in range(1, len(explore_windows_df)):
current_window = explore_windows_df.iloc[i]
last_bout = explore_bouts_df.iloc[-1]
if current_window["start"] == last_bout["end"]: # continue bout
explore_bouts_df.at[len(explore_bouts_df) - 1, "end"] = current_window["end"]
explore_bouts_df.at[len(explore_bouts_df) - 1, "duration"] = (
explore_bouts_df.iloc[-1]["end"] - explore_bouts_df.iloc[-1]["start"]
)
else: # start a new bout
new_bout = {
"subject": subject,
"start": current_window["start"],
"end": current_window["end"],
"duration": current_window["duration"],
"period": current_window["period"],
}
explore_bouts_df = pd.concat(
[explore_bouts_df, pd.DataFrame([new_bout])], ignore_index=True
)
# </s>
return explore_bouts_df
"""Save explore bouts to parquet files for each experiment and period"""
# For each experiment, for each period, load pos data, get explore bouts, save to parquet
pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
    explore_bouts_data_dict = {}
key = {"experiment_name": exp["name"]}
# get nest center for this exp
epoch_query = acquisition.Epoch & (acquisition.Chunk & key).proj("epoch_start")
active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query
roi_locs = dict(
zip(*active_region_query.fetch("region_name", "region_data"), strict=True)
)
points = roi_locs["NestRegion"]["ArrayOfPoint"]
vertices = np.array([[float(point["X"]), float(point["Y"])] for point in points])
nest_center = np.mean(vertices, axis=0)
    explore_bouts_data_dict[exp["name"]] = {}
periods = {
"presocial": (exp["presocial_start"], exp["presocial_end"]),
"social": (exp["social_start"], exp["social_end"]),
"postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}
pbar_period = tqdm(periods.items(), desc="Processing periods", leave=False)
for period_name, (period_start, period_end) in pbar_period:
print(f" Loading {period_name} period...")
# load pos data for this period
pos_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period_name,
data_type="position",
data_dir=data_dir,
set_time_index=True,
)
        # get explore bouts for each subject (skip if no position data was found)
        if pos_df.empty:
            continue
        subjects = pos_df["identity_name"].unique()
        explore_bouts_df = pd.DataFrame(
            columns=["subject", "start", "end", "duration", "period"]
        )
        for subject_name in subjects:  # avoid shadowing the `subject` module
            subject_bouts = explore_bouts(pos_df, subject_name, nest_center)
            if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
                explore_bouts_df = pd.concat(
                    [explore_bouts_df, subject_bouts], ignore_index=True
                )
        # save data dict
        explore_bouts_data_dict[exp["name"]][period_name] = explore_bouts_df
        save_all_experiment_data(
            experiments=[exp],
            periods=[period_name],
            data_dict=explore_bouts_data_dict,
data_type="explore",
data_dir=data_dir,
)
print(f" Saved explore bouts for {exp['name']} during {period_name} period.")
key = {"experiment_name": "social0.2-aeon3"}
epoch_query = acquisition.Epoch & (acquisition.Chunk & key).proj("epoch_start")
active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query
roi_locs = dict(
zip(*active_region_query.fetch("region_name", "region_data"), strict=True)
)
"""Example usage:"""
explore_df = load_data_from_parquet(
experiment_name="social0.2-aeon3",
period="presocial",
data_type="explore",
data_dir=data_dir,
set_time_index=True,
)
display(explore_df)