{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"# %flow mode reactive\n",
"\n",
"import datetime\n",
"import sys\n",
"import os\n",
"import warnings\n",
"from pathlib import Path\n",
"from typing import Any, Tuple, List, Dict\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import plotly\n",
"import plotly.express as px\n",
"import plotly.graph_objs as go\n",
"import statsmodels.api as sm\n",
"from plotly.subplots import make_subplots\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import datajoint as dj\n",
"from aeon.dj_pipeline.analysis.block_analysis import *\n",
"from aeon.dj_pipeline import acquisition, streams, subject\n",
"from swc.aeon.io import api as aeon_api\n",
"from aeon.schema.schemas import social02"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Output directory for the exported parquet files. Set the AEON_METHODS_DATA_DIR\n",
"# environment variable to write elsewhere instead of editing this cell.\n",
"data_dir = Path(\n",
"    os.environ.get(\n",
"        \"AEON_METHODS_DATA_DIR\", \"/ceph/aeon/aeon/code/scratchpad/methods_paper_data\"\n",
"    )\n",
")\n",
"data_dir.mkdir(parents=True, exist_ok=True)\n",
"cm2px = 5.2  # 1 cm = 5.2 px roughly for top camera"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiments = [\n",
" {\n",
" \"name\": \"social0.2-aeon3\",\n",
" \"presocial_start\": \"2024-01-31 11:00:00\",\n",
" \"presocial_end\": \"2024-02-08 15:00:00\",\n",
" \"social_start\": \"2024-02-09 16:00:00\",\n",
" \"social_end\": \"2024-02-23 13:00:00\",\n",
" \"postsocial_start\": \"2024-02-25 17:00:00\",\n",
" \"postsocial_end\": \"2024-03-02 14:00:00\",\n",
" },\n",
" {\n",
" \"name\": \"social0.2-aeon4\",\n",
" \"presocial_start\": \"2024-01-31 11:00:00\",\n",
" \"presocial_end\": \"2024-02-08 15:00:00\",\n",
" \"social_start\": \"2024-02-09 17:00:00\",\n",
" \"social_end\": \"2024-02-23 12:00:00\",\n",
" \"postsocial_start\": \"2024-02-25 18:00:00\",\n",
" \"postsocial_end\": \"2024-03-02 13:00:00\",\n",
" },\n",
" {\n",
" \"name\": \"social0.3-aeon3\",\n",
" \"presocial_start\": \"2024-06-08 19:00:00\",\n",
" \"presocial_end\": \"2024-06-17 13:00:00\",\n",
" \"social_start\": \"2024-06-25 11:00:00\",\n",
" \"social_end\": \"2024-07-06 13:00:00\",\n",
" \"postsocial_start\": \"2024-07-07 16:00:00\",\n",
" \"postsocial_end\": \"2024-07-14 14:00:00\",\n",
" },\n",
" {\n",
" \"name\": \"social0.3-aeon4\",\n",
" \"presocial_start\": \"2024-06-08 19:00:00\",\n",
" \"presocial_end\": \"2024-06-17 14:00:00\",\n",
" \"social_start\": \"2024-06-19 12:00:00\",\n",
" \"social_end\": \"2024-07-03 14:00:00\",\n",
" \"postsocial_start\": \"2024-07-04 11:00:00\",\n",
" \"postsocial_end\": \"2024-07-13 12:00:00\",\n",
" },\n",
" {\n",
" \"name\": \"social0.4-aeon3\",\n",
" \"presocial_start\": \"2024-08-16 17:00:00\",\n",
" \"presocial_end\": \"2024-08-24 10:00:00\",\n",
" \"social_start\": \"2024-08-28 11:00:00\",\n",
" \"social_end\": \"2024-09-09 13:00:00\",\n",
" \"postsocial_start\": \"2024-09-09 18:00:00\",\n",
" \"postsocial_end\": \"2024-09-22 16:00:00\",\n",
" },\n",
" {\n",
" \"name\": \"social0.4-aeon4\",\n",
" \"presocial_start\": \"2024-08-16 15:00:00\",\n",
" \"presocial_end\": \"2024-08-24 10:00:00\",\n",
" \"social_start\": \"2024-08-28 10:00:00\",\n",
" \"social_end\": \"2024-09-09 01:00:00\",\n",
" \"postsocial_start\": \"2024-09-09 15:00:00\",\n",
" \"postsocial_end\": \"2024-09-22 16:00:00\",\n",
" },\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Helper functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def save_data_to_parquet(\n",
"    df: pd.DataFrame,\n",
"    experiment_name: str,\n",
"    period_name: str,\n",
"    data_type: str,\n",
"    data_dir: Path\n",
") -> Path:\n",
"    \"\"\"Saves a DataFrame to `{experiment_name}_{period_name}_{data_type}.parquet`.\n",
"\n",
"    A copy of `df` is written, so the caller's frame is not modified. A `period`\n",
"    column is added when missing. If the index has a name other than `time`, it is\n",
"    moved into regular columns before writing; a `time` index is written as the index.\n",
"\n",
"    Args:\n",
"        df (pd.DataFrame): Data to save.\n",
"        experiment_name (str): Name of the experiment.\n",
"        period_name (str): Period name (presocial, social, postsocial).\n",
"        data_type (str): Type of data (position, patch, foraging, rfid, sleep, explore).\n",
"        data_dir (Path | str): Directory to save the file; created if it does not exist.\n",
"\n",
"    Returns:\n",
"        Path: Path to the saved file.\n",
"    \"\"\"\n",
"    data_dir = Path(data_dir)  # also accept plain string paths\n",
"    data_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"    # Work on a copy so adding the period column does not mutate the input\n",
"    df = df.copy()\n",
"    if \"period\" not in df.columns:\n",
"        df[\"period\"] = period_name\n",
"\n",
"    # Keep a `time` index as-is; any other named index becomes regular columns\n",
"    if df.index.name and df.index.name != \"time\":\n",
"        df = df.reset_index()\n",
"\n",
"    file_path = data_dir / f\"{experiment_name}_{period_name}_{data_type}.parquet\"\n",
"\n",
"    print(f\" Saving to {file_path}...\")\n",
"    df.to_parquet(file_path, compression=\"snappy\")\n",
"\n",
"    # Report the on-disk size of the written file\n",
"    file_size_mb = file_path.stat().st_size / (1024 * 1024)\n",
"    print(f\" Saved successfully: {len(df)} rows, {file_size_mb:.2f} MB on disk\")\n",
"\n",
"    return file_path\n",
"\n",
"\n",
"def load_data_from_parquet(\n",
" experiment_name: str | None,\n",
" period: str | None,\n",
" data_type: str,\n",
" data_dir: Path,\n",
" set_time_index: bool = False\n",
") -> pd.DataFrame:\n",
" \"\"\"Loads saved data from parquet files.\n",
"\n",
" Args:\n",
" experiment_name (str, optional): Filter by experiment name. If None, load all experiments.\n",
" period (str, optional): Filter by period (presocial, social, postsocial). If None, load all periods.\n",
" data_type (str): Type of data to load (position, patch, foraging, rfid, sleep, explore)\n",
" data_dir (Path): Directory containing parquet files.\n",
" set_time_index (bool, optional): If True, set 'time' column as DataFrame index.\n",
"\n",
" Returns:\n",
" pd.DataFrame: Combined DataFrame of all matching parquet files.\n",
" \"\"\"\n",
" if not data_dir.exists():\n",
" print(f\"Directory {data_dir} does not exist. No data files found.\")\n",
" return pd.DataFrame()\n",
"\n",
" # Create pattern based on filters\n",
" pattern = \"\"\n",
" if experiment_name:\n",
" pattern += f\"{experiment_name}_\"\n",
" else:\n",
" pattern += \"*_\"\n",
"\n",
" if period:\n",
" pattern += f\"{period}_\"\n",
" else:\n",
" pattern += \"*_\"\n",
"\n",
" pattern += f\"{data_type}.parquet\"\n",
"\n",
" # Find matching files\n",
" matching_files = list(data_dir.glob(pattern))\n",
"\n",
" if not matching_files:\n",
" print(f\"No matching data files found with pattern: {pattern}\")\n",
" return pd.DataFrame()\n",
"\n",
" print(f\"Found {len(matching_files)} matching files\")\n",
"\n",
" # Load and concatenate matching files\n",
" dfs = []\n",
" total_rows = 0\n",
" for file in matching_files:\n",
" print(f\"Loading {file}...\")\n",
" df = pd.read_parquet(file)\n",
" total_rows += len(df)\n",
" dfs.append(df)\n",
" print(f\" Loaded {len(df)} rows\")\n",
"\n",
" # Combine data\n",
" if dfs:\n",
" combined_df = pd.concat(dfs, ignore_index=True)\n",
" if set_time_index and 'time' in combined_df.columns:\n",
" combined_df = combined_df.set_index('time')\n",
" print(f\"Combined data: {len(combined_df)} rows\")\n",
" return combined_df\n",
" else:\n",
" return pd.DataFrame()\n",
"\n",
"\n",
"def save_all_experiment_data(\n",
" experiments: list,\n",
" periods: list,\n",
" data_dict: dict,\n",
" data_type: str,\n",
" data_dir: Path\n",
") -> None:\n",
" \"\"\"Save data for all experiments and periods in a standardized way.\n",
" \n",
" Args:\n",
" experiments (list): List of experiment dictionaries with 'name' field\n",
" data_dict (dict): Nested dictionary with structure {exp_name: {period_name: dataframe}}\n",
" data_type (str): Type of data (position, patch, foraging, rfid, sleep, explore)\n",
" data_dir (Path): Directory to save files\n",
" periods (list): List of periods to process\n",
" \"\"\"\n",
" # Save individual experiment data\n",
" for exp in experiments:\n",
" for period in periods:\n",
" df = data_dict[exp['name']][period]\n",
" if isinstance(df, pd.DataFrame) and not df.empty:\n",
" save_data_to_parquet(\n",
" df,\n",
" exp['name'],\n",
" period,\n",
" data_type,\n",
" data_dir\n",
" )\n",
"\n",
"\n",
"def excise_swaps(pos_df: pd.DataFrame, max_speed: float) -> pd.DataFrame:\n",
" \"\"\"Excises swaps in the position data.\n",
"\n",
" Args:\n",
" pos_df (pd.DataFrame): DataFrame containing position data of a single subject.\n",
" max_speed (float): Maximum speed (px/s) threshold over which we assume a swap.\n",
"\n",
" Returns:\n",
" pd.DataFrame: DataFrame with swaps excised.\n",
" \"\"\"\n",
" dt = pos_df.index.diff().total_seconds()\n",
" dx = pos_df[\"x\"].diff()\n",
" dy = pos_df[\"y\"].diff()\n",
" pos_df[\"inst_speed\"] = np.sqrt(dx**2 + dy**2) / dt\n",
"\n",
" # Identify jumps\n",
" jumps = (pos_df[\"inst_speed\"] > max_speed)\n",
" shift_down = jumps.shift(1)\n",
" shift_down.iloc[0] = False\n",
" shift_up = jumps.shift(-1)\n",
" shift_up.iloc[len(jumps) - 1] = False\n",
" jump_starts = jumps & ~shift_down\n",
" jump_ends = jumps & ~shift_up\n",
" jump_start_indices = np.where(jump_starts)[0]\n",
" jump_end_indices = np.where(jump_ends)[0]\n",
"\n",
" if np.any(jumps):\n",
"\n",
" # Ensure the lengths match\n",
" if len(jump_start_indices) > len(jump_end_indices): # jump-in-progress at start\n",
" jump_end_indices = np.append(jump_end_indices, len(pos_df) - 1)\n",
" elif len(jump_start_indices) < len(jump_end_indices): # jump-in-progress at end\n",
" jump_start_indices = np.insert(jump_start_indices, 0, 0)\n",
"\n",
" # Excise jumps by setting speed to nan in jump regions and dropping nans\n",
" for start, end in zip(jump_start_indices, jump_end_indices, strict=True):\n",
" pos_df.loc[pos_df.index[start]:pos_df.index[end], \"inst_speed\"] = np.nan\n",
" pos_df.dropna(subset=[\"inst_speed\"], inplace=True)\n",
"\n",
" return pos_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data loading and saving"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Patch data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_subject_patch_data(\n",
"    key: dict[str, str], period_start: str, period_end: str\n",
") -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:\n",
"    \"\"\"Loads patch info and per-subject patch / preference data for a time period.\n",
"\n",
"    All three tables are restricted by `key` and by\n",
"    `period_start <= block_start <= period_end`.\n",
"\n",
"    Args:\n",
"        key (dict): The key to filter the subject patch data.\n",
"        period_start (str): The start time for the period.\n",
"        period_end (str): The end time for the period.\n",
"\n",
"    Returns:\n",
"        tuple: A tuple containing:\n",
"            - patch_info (pd.DataFrame): Information about patches (empty DataFrame if none).\n",
"            - block_subject_patch_data (pd.DataFrame): Data for the specified period.\n",
"            - block_subject_patch_pref (pd.DataFrame): Preference data for the specified period.\n",
"        Non-empty frames from `BlockSubjectAnalysis` have their index reset into columns.\n",
"    \"\"\"\n",
"\n",
"    def _restrict(table):\n",
"        \"\"\"Restricts `table` to `key` and to blocks starting within the period.\"\"\"\n",
"        return (\n",
"            table\n",
"            & key\n",
"            & f\"block_start >= '{period_start}'\"\n",
"            & f\"block_start <= '{period_end}'\"\n",
"        )\n",
"\n",
"    patch_info = _restrict(BlockAnalysis.Patch()).fetch(\n",
"        \"block_start\",\n",
"        \"patch_name\",\n",
"        \"patch_rate\",\n",
"        \"patch_offset\",\n",
"        \"wheel_timestamps\",\n",
"        as_dict=True,\n",
"    )\n",
"    block_subject_patch_data = _restrict(BlockSubjectAnalysis.Patch()).fetch(format=\"frame\")\n",
"    block_subject_patch_pref = _restrict(BlockSubjectAnalysis.Preference()).fetch(format=\"frame\")\n",
"\n",
"    # Wrap the list of dicts even when it is empty so callers always get a DataFrame\n",
"    patch_info = pd.DataFrame(patch_info)\n",
"\n",
"    if isinstance(block_subject_patch_data, pd.DataFrame) and not block_subject_patch_data.empty:\n",
"        block_subject_patch_data = block_subject_patch_data.reset_index()\n",
"\n",
"    if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:\n",
"        block_subject_patch_pref = block_subject_patch_pref.reset_index()\n",
"\n",
"    return patch_info, block_subject_patch_data, block_subject_patch_pref"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def ensure_ts_arr_datetime(array):\n",
" if len(array) == 0:\n",
" return np.array([], dtype=\"datetime64[ns]\")\n",
" else:\n",
" return np.array(array, dtype=\"datetime64[ns]\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"patch_info_dict = {}\n",
"subject_patch_data_dict = {}\n",
"subject_patch_pref_dict = {}\n",
"\n",
"for exp in experiments:\n",
"    key = {\"experiment_name\": exp[\"name\"]}\n",
"\n",
"    # Define periods\n",
"    periods = {\n",
"        \"presocial\": (exp[\"presocial_start\"], exp[\"presocial_end\"]),\n",
"        \"social\": (exp[\"social_start\"], exp[\"social_end\"]),\n",
"        \"postsocial\": (exp[\"postsocial_start\"], exp[\"postsocial_end\"]),\n",
"    }\n",
"\n",
"    # Initialize nested dictionaries for this experiment\n",
"    patch_info_dict[exp[\"name\"]] = {}\n",
"    subject_patch_data_dict[exp[\"name\"]] = {}\n",
"    subject_patch_pref_dict[exp[\"name\"]] = {}\n",
"\n",
"    # Load data for each period\n",
"    for period_name, (period_start, period_end) in periods.items():\n",
"        # Parse with pandas so this does not depend on whether the name `datetime`\n",
"        # is bound to the module (`import datetime`) or to the class\n",
"        period_start = pd.to_datetime(period_start, format=\"%Y-%m-%d %H:%M:%S\").to_pydatetime()\n",
"        period_end = pd.to_datetime(period_end, format=\"%Y-%m-%d %H:%M:%S\").to_pydatetime()\n",
"\n",
"        # Load data for this period\n",
"        patch_info, block_subject_patch_data, block_subject_patch_pref = (\n",
"            load_subject_patch_data(key, period_start, period_end)\n",
"        )\n",
"\n",
"        # Drop nans for 'final_preference' columns\n",
"        block_subject_patch_pref = block_subject_patch_pref.dropna(\n",
"            subset=[\"final_preference_by_time\", \"final_preference_by_wheel\"]\n",
"        )\n",
"\n",
"        # Extra processing on patch_info\n",
"        if isinstance(patch_info, pd.DataFrame) and not patch_info.empty:\n",
"            # Add experiment_name and period columns for reference\n",
"            patch_info.insert(0, \"experiment_name\", exp[\"name\"])\n",
"            patch_info.insert(1, \"period\", period_name)\n",
"\n",
"        # Extra processing on block_subject_patch_data\n",
"        if isinstance(block_subject_patch_data, pd.DataFrame) and not block_subject_patch_data.empty:\n",
"            # Add period column for reference\n",
"            block_subject_patch_data.insert(1, \"period\", period_name)\n",
"\n",
"            # # Remove dummy patches\n",
"            # block_subject_patch_data = block_subject_patch_data[\n",
"            #     ~block_subject_patch_data[\"patch_name\"].str.contains(\"PatchDummy\")\n",
"            # ]\n",
"\n",
"            # For pre-social and post-social periods check n_subjects per block (should == 1)\n",
"            if period_name in [\"presocial\", \"postsocial\"]:\n",
"                n_subjects = block_subject_patch_data.groupby(\"block_start\")[\n",
"                    \"subject_name\"\n",
"                ].nunique()\n",
"                if (n_subjects != 1).any():\n",
"                    warnings.warn(\n",
"                        f\"Pre or post social data for {exp['name']} has blocks with more than one \"\n",
"                        f\"subject being tracked. Data needs to be fixed or cleaned.\"\n",
"                    )\n",
"\n",
"            # Ensure timestamp array columns have the correct type (datetime64[ns])\n",
"            for ts_col in (\"pellet_timestamps\", \"in_patch_rfid_timestamps\", \"in_patch_timestamps\"):\n",
"                if ts_col in block_subject_patch_data.columns:\n",
"                    block_subject_patch_data[ts_col] = block_subject_patch_data[ts_col].apply(\n",
"                        ensure_ts_arr_datetime\n",
"                    )\n",
"\n",
"        if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:\n",
"            # Add period column for reference\n",
"            block_subject_patch_pref.insert(1, \"period\", period_name)\n",
"\n",
"            # # Remove dummy patches\n",
"            # block_subject_patch_pref = block_subject_patch_pref[\n",
"            #     ~block_subject_patch_pref[\"patch_name\"].str.contains(\"PatchDummy\")\n",
"            # ]\n",
"\n",
"        # Store the data\n",
"        patch_info_dict[exp[\"name\"]][period_name] = patch_info\n",
"        subject_patch_data_dict[exp[\"name\"]][period_name] = block_subject_patch_data\n",
"        subject_patch_pref_dict[exp[\"name\"]][period_name] = block_subject_patch_pref"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preview the final experiment's post-social tables, read from the result dicts\n",
"# rather than from loop variables left over from the previous cell\n",
"last_exp_name = experiments[-1][\"name\"]\n",
"display(patch_info_dict[last_exp_name][\"postsocial\"])\n",
"display(subject_patch_data_dict[last_exp_name][\"postsocial\"])\n",
"display(subject_patch_pref_dict[last_exp_name][\"postsocial\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist the three patch tables: one parquet file per experiment, period and table\n",
"for patch_table_dict, patch_data_type in (\n",
"    (patch_info_dict, \"patchinfo\"),\n",
"    (subject_patch_data_dict, \"patch\"),\n",
"    (subject_patch_pref_dict, \"patchpref\"),\n",
"):\n",
"    save_all_experiment_data(\n",
"        experiments=experiments,\n",
"        periods=[\"presocial\", \"social\", \"postsocial\"],\n",
"        data_dict=patch_table_dict,\n",
"        data_type=patch_data_type,\n",
"        data_dir=data_dir,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload one experiment's social-period patch tables from the parquet files\n",
"combined_social_patch_info_df = load_data_from_parquet(\n",
"    \"social0.2-aeon3\", \"social\", \"patchinfo\", data_dir\n",
")\n",
"display(combined_social_patch_info_df)\n",
"\n",
"combined_social_patch_df = load_data_from_parquet(\n",
"    \"social0.2-aeon3\", \"social\", \"patch\", data_dir\n",
")\n",
"display(combined_social_patch_df)\n",
"\n",
"combined_social_patch_pref_df = load_data_from_parquet(\n",
"    \"social0.2-aeon3\", \"social\", \"patchpref\", data_dir\n",
")\n",
"display(combined_social_patch_pref_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Foraging bouts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_foraging_bouts(\n",
"    key: Dict[str, str], period_start: str, period_end: str\n",
") -> pd.DataFrame:\n",
"    \"\"\"Collects foraging bouts (via `get_foraging_bouts(..., min_pellets=1)`) from blocks in a period.\n",
"\n",
"    Only blocks that start at or after `period_start` and end at or before\n",
"    `period_end` are used.\n",
"\n",
"    Args:\n",
"        key (dict): Key to identify experiment data (e.g., {\"experiment_name\": \"Exp1\"}).\n",
"        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').\n",
"        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: Foraging bouts of all matching blocks, concatenated with a fresh\n",
"            index. If no block matches, an empty frame with columns start, end,\n",
"            n_pellets, cum_wheel_dist and subject.\n",
"    \"\"\"\n",
"    block_query = (\n",
"        Block & key & f\"block_start >= '{period_start}'\" & f\"block_end <= '{period_end}'\"\n",
"    )\n",
"    bouts = [\n",
"        get_foraging_bouts(key | {\"block_start\": str(block_start)}, min_pellets=1)\n",
"        for block_start in block_query.fetch(\"block_start\")\n",
"    ]\n",
"\n",
"    if not bouts:\n",
"        return pd.DataFrame(\n",
"            columns=[\"start\", \"end\", \"n_pellets\", \"cum_wheel_dist\", \"subject\"]\n",
"        )\n",
"    return pd.concat(bouts, ignore_index=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a dictionary to hold foraging data for each experiment and period\n",
"foraging_data_dict = {}\n",
"\n",
"for exp in experiments:\n",
"    key = {\"experiment_name\": exp[\"name\"]}\n",
"\n",
"    # Initialize nested dictionary for this experiment\n",
"    foraging_data_dict[exp[\"name\"]] = {}\n",
"\n",
"    # Define periods\n",
"    periods = {\n",
"        \"presocial\": (exp[\"presocial_start\"], exp[\"presocial_end\"]),\n",
"        \"social\": (exp[\"social_start\"], exp[\"social_end\"]),\n",
"        \"postsocial\": (exp[\"postsocial_start\"], exp[\"postsocial_end\"]),\n",
"    }\n",
"\n",
"    # Load data for each period\n",
"    for period_name, (period_start, period_end) in periods.items():\n",
"        # Parse with pandas so this does not depend on whether the name `datetime`\n",
"        # is bound to the module (`import datetime`) or to the class\n",
"        period_start = pd.to_datetime(period_start, format=\"%Y-%m-%d %H:%M:%S\").to_pydatetime()\n",
"        period_end = pd.to_datetime(period_end, format=\"%Y-%m-%d %H:%M:%S\").to_pydatetime()\n",
"\n",
"        # Load foraging data for this period\n",
"        foraging_df = load_foraging_bouts(key, period_start, period_end)\n",
"\n",
"        # Add experiment name as a column if not already present\n",
"        if \"experiment_name\" not in foraging_df.columns:\n",
"            foraging_df.insert(0, \"experiment_name\", exp[\"name\"])\n",
"\n",
"        # Add period column for reference\n",
"        foraging_df[\"period\"] = period_name\n",
"\n",
"        # Store the data\n",
"        foraging_data_dict[exp[\"name\"]][period_name] = foraging_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist foraging bouts: one parquet file per experiment and period\n",
"save_all_experiment_data(\n",
"    experiments,\n",
"    [\"presocial\", \"social\", \"postsocial\"],\n",
"    foraging_data_dict,\n",
"    \"foraging\",\n",
"    data_dir,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload one experiment's social-period foraging bouts from parquet\n",
"combined_social_foraging_df = load_data_from_parquet(\n",
"    \"social0.2-aeon3\", \"social\", \"foraging\", data_dir\n",
")\n",
"display(combined_social_foraging_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RFID data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_rfid_events(\n",
"    key: Dict[str, str], period_start: str, period_end: str\n",
") -> pd.DataFrame:\n",
"    \"\"\"Loads RFID events data for chunks starting within a specified time period.\n",
"\n",
"    Tag IDs in the `rfid` column are replaced by subject names using\n",
"    `subject.SubjectDetail` (`lab_id` -> `subject`); unknown tags become None.\n",
"\n",
"    Args:\n",
"        key (dict): Key to identify experiment data (e.g., {\"experiment_name\": \"Exp1\"}).\n",
"        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').\n",
"        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: RFID events with `experiment_name` and `chunk_start` as the\n",
"            first two columns. Returns an empty dataframe with predefined columns\n",
"            if no data found.\n",
"    \"\"\"\n",
"    # Fetch RFID events within the specified period\n",
"    rfid_events_df = (\n",
"        streams.RfidReader * streams.RfidReaderRfidEvents\n",
"        & key\n",
"        & f'chunk_start >= \"{period_start}\"'\n",
"        & f'chunk_start <= \"{period_end}\"'\n",
"    ).fetch(format=\"frame\")\n",
"\n",
"    # Check the type first so `.empty` is only read from a DataFrame\n",
"    if not isinstance(rfid_events_df, pd.DataFrame) or rfid_events_df.empty:\n",
"        # Return empty DataFrame with expected columns if no data found\n",
"        return pd.DataFrame(\n",
"            columns=[\n",
"                \"experiment_name\",\n",
"                \"chunk_start\",\n",
"                \"rfid_reader_name\",\n",
"                \"sample_count\",\n",
"                \"timestamps\",\n",
"                \"rfid\",\n",
"            ]\n",
"        )\n",
"\n",
"    # Map each RFID tag (stored as `lab_id`) to its subject name\n",
"    subject_detail = subject.SubjectDetail.fetch(format=\"frame\").reset_index()\n",
"    rfid_to_lab_id = dict(zip(subject_detail[\"lab_id\"], subject_detail[\"subject\"]))\n",
"    rfid_events_df[\"rfid\"] = [\n",
"        [rfid_to_lab_id.get(str(rfid)) for rfid in rfid_array]\n",
"        for rfid_array in rfid_events_df[\"rfid\"]\n",
"    ]\n",
"\n",
"    # Copy experiment_name / chunk_start out of the index before dropping it.\n",
"    # Levels are looked up by name, falling back to positions 0 and 3.\n",
"    index_names = list(rfid_events_df.index.names)\n",
"    exp_level = \"experiment_name\" if \"experiment_name\" in index_names else 0\n",
"    chunk_level = \"chunk_start\" if \"chunk_start\" in index_names else 3\n",
"    rfid_events_df[\"experiment_name\"] = rfid_events_df.index.get_level_values(exp_level)\n",
"    rfid_events_df[\"chunk_start\"] = rfid_events_df.index.get_level_values(chunk_level)\n",
"    rfid_events_df = rfid_events_df.reset_index(drop=True)\n",
"\n",
"    # Reorder columns to put experiment_name first and chunk_start second\n",
"    leading_cols = [\"experiment_name\", \"chunk_start\"]\n",
"    rfid_events_df = rfid_events_df[\n",
"        leading_cols + [col for col in rfid_events_df.columns if col not in leading_cols]\n",
"    ]\n",
"\n",
"    return rfid_events_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a dictionary to hold RFID data for each experiment and period\n",
"rfid_data_dict = {}\n",
"\n",
"for exp in experiments:\n",
"    exp_name = exp[\"name\"]\n",
"    key = {\"experiment_name\": exp_name}\n",
"    rfid_data_dict[exp_name] = {}\n",
"\n",
"    # (start, end) bounds of each period for this experiment\n",
"    periods = {\n",
"        \"presocial\": (exp[\"presocial_start\"], exp[\"presocial_end\"]),\n",
"        \"social\": (exp[\"social_start\"], exp[\"social_end\"]),\n",
"        \"postsocial\": (exp[\"postsocial_start\"], exp[\"postsocial_end\"]),\n",
"    }\n",
"\n",
"    for period_name, (period_start, period_end) in periods.items():\n",
"        # The loader interpolates the bounds into its query, so hand it strings;\n",
"        # any non-string bound is formatted as \"%Y-%m-%d %H:%M:%S\"\n",
"        period_start_str, period_end_str = (\n",
"            bound if isinstance(bound, str) else bound.strftime(\"%Y-%m-%d %H:%M:%S\")\n",
"            for bound in (period_start, period_end)\n",
"        )\n",
"\n",
"        rfid_df = load_rfid_events(key, period_start_str, period_end_str)\n",
"\n",
"        # Tag rows with their experiment (if the loader did not) and period\n",
"        if \"experiment_name\" not in rfid_df.columns:\n",
"            rfid_df.insert(0, \"experiment_name\", exp_name)\n",
"        rfid_df[\"period\"] = period_name\n",
"\n",
"        rfid_data_dict[exp_name][period_name] = rfid_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist RFID events: one parquet file per experiment and period\n",
"save_all_experiment_data(\n",
"    experiments,\n",
"    [\"presocial\", \"social\", \"postsocial\"],\n",
"    rfid_data_dict,\n",
"    \"rfid\",\n",
"    data_dir,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload one experiment's social-period RFID events from parquet\n",
"combined_social_rfid_df = load_data_from_parquet(\n",
"    \"social0.2-aeon3\", \"social\", \"rfid\", data_dir\n",
")\n",
"display(combined_social_rfid_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Position data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_position_data(\n",
"    key: Dict[str, str], period_start: str, period_end: str\n",
") -> pd.DataFrame:\n",
"    \"\"\"Loads CameraTop centroid tracking (DenoisedTracking) for a specified time period.\n",
"\n",
"    The per-chunk arrays are exploded to one row per sample, numeric columns are\n",
"    cast to float and `time` is converted to datetime before becoming the index.\n",
"\n",
"    Args:\n",
"        key (dict): Key to identify experiment data (e.g., {\"experiment_name\": \"Exp1\"}).\n",
"        period_start (str): Start datetime of the time period.\n",
"        period_end (str): End datetime of the time period.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: Position data indexed by `time` with columns experiment_name,\n",
"            identity_name, identity_likelihood, x, y and likelihood. Empty if no data\n",
"            is found or loading fails (the error is printed).\n",
"    \"\"\"\n",
"    try:\n",
"        print(f\" Querying data from {period_start} to {period_end}...\")\n",
"\n",
"        # Create chunk restriction for the time period\n",
"        chunk_restriction = acquisition.create_chunk_restriction(\n",
"            key[\"experiment_name\"], period_start, period_end\n",
"        )\n",
"\n",
"        # Fetch centroid tracking data for the specified period\n",
"        centroid_df = (\n",
"            streams.SpinnakerVideoSource * tracking.DenoisedTracking.Subject\n",
"            & key\n",
"            & {\"spinnaker_video_source_name\": \"CameraTop\"}\n",
"            & chunk_restriction\n",
"        ).fetch(format=\"frame\")\n",
"\n",
"        centroid_df = (\n",
"            centroid_df.reset_index()\n",
"            .rename(\n",
"                columns={\n",
"                    \"subject_name\": \"identity_name\",\n",
"                    \"timestamps\": \"time\",\n",
"                    \"subject_likelihood\": \"identity_likelihood\",\n",
"                }\n",
"            )\n",
"            .explode([\"time\", \"identity_likelihood\", \"x\", \"y\", \"likelihood\"])\n",
"        )\n",
"        # `explode` leaves object dtypes behind; cast to typed columns\n",
"        centroid_df = centroid_df.astype(\n",
"            {\"identity_likelihood\": float, \"x\": float, \"y\": float, \"likelihood\": float}\n",
"        )\n",
"        centroid_df[\"time\"] = pd.to_datetime(centroid_df[\"time\"])\n",
"\n",
"        # Keeping only these columns also drops spinnaker_video_source_name\n",
"        centroid_df = centroid_df[\n",
"            [\n",
"                \"time\",\n",
"                \"experiment_name\",\n",
"                \"identity_name\",\n",
"                \"identity_likelihood\",\n",
"                \"x\",\n",
"                \"y\",\n",
"                \"likelihood\",\n",
"            ]\n",
"        ].set_index(\"time\")\n",
"\n",
"        if centroid_df.empty:\n",
"            print(\" No data found for the specified period\")\n",
"        else:\n",
"            print(f\" Retrieved {len(centroid_df)} rows of position data\")\n",
"\n",
"        return centroid_df\n",
"\n",
"    except Exception as e:\n",
"        print(\n",
"            f\" Error loading position data for {key['experiment_name']} ({period_start} \"\n",
"            f\"to {period_end}): {e}\"\n",
"        )\n",
"        return pd.DataFrame()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a dictionary to hold position data for each experiment and period\n",
"position_data_dict = {}\n",
"\n",
"for exp in experiments:\n",
"    key = {\"experiment_name\": exp[\"name\"]}\n",
"\n",
"    # Initialize nested dictionary for this experiment\n",
"    position_data_dict[exp[\"name\"]] = {}\n",
"\n",
"    # Define periods\n",
"    periods = {\n",
"        \"presocial\": (exp[\"presocial_start\"], exp[\"presocial_end\"]),\n",
"        \"social\": (exp[\"social_start\"], exp[\"social_end\"]),\n",
"        \"postsocial\": (exp[\"postsocial_start\"], exp[\"postsocial_end\"]),\n",
"    }\n",
"\n",
"    # Load data for each period\n",
"    for period_name, (period_start, period_end) in periods.items():\n",
"        print(f\" Loading {period_name} period...\")\n",
"\n",
"        # Parse with pandas so this does not depend on whether the name `datetime`\n",
"        # is bound to the module (`import datetime`) or to the class\n",
"        period_start = pd.to_datetime(period_start, format=\"%Y-%m-%d %H:%M:%S\").to_pydatetime()\n",
"        period_end = pd.to_datetime(period_end, format=\"%Y-%m-%d %H:%M:%S\").to_pydatetime()\n",
"\n",
"        # Load position data and move the `time` index into a regular column\n",
"        position_df = load_position_data(key, period_start, period_end).reset_index()\n",
"        has_data = isinstance(position_df, pd.DataFrame) and not position_df.empty\n",
"\n",
"        # Add period column for reference if not empty\n",
"        if has_data:\n",
"            position_df[\"period\"] = period_name\n",
"\n",
"        # Store the data\n",
"        position_data_dict[exp[\"name\"]][period_name] = position_df\n",
"\n",
"        # Print data size info\n",
"        if has_data:\n",
"            memory_usage_mb = position_df.memory_usage(deep=True).sum() / (1024 * 1024)\n",
"            print(\n",
"                f\" {period_name}: {len(position_df)} rows, {memory_usage_mb:.2f} MB in memory\"\n",
"            )\n",
"        else:\n",
"            print(f\" {period_name}: No data available\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist position data: one parquet file per experiment and period\n",
"save_all_experiment_data(\n",
"    experiments,\n",
"    [\"presocial\", \"social\", \"postsocial\"],\n",
"    position_data_dict,\n",
"    \"position\",\n",
"    data_dir,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload one experiment's social-period positions from parquet, indexed by time\n",
"combined_social_position_df = load_data_from_parquet(\n",
"    \"social0.4-aeon3\", \"social\", \"position\", data_dir, set_time_index=True\n",
")\n",
"display(combined_social_position_df.sort_index())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weight data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_weight_data(\n",
"    key: Dict[str, str], period_start: str, period_end: str\n",
") -> pd.DataFrame:\n",
"    \"\"\"Loads subject weight data for a specified time period.\n",
"\n",
"    Queries `acquisition.Environment.SubjectWeight` restricted to `key` and to chunks\n",
"    whose `chunk_start` lies in the closed interval [period_start, period_end].\n",
"\n",
"    Args:\n",
"        key (dict): Key to identify experiment data (e.g., {\"experiment_name\": \"Exp1\"}).\n",
"        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').\n",
"        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: Weight data for the specified period. Returns an empty DataFrame if\n",
"        no data is found or if the query fails (the error is printed, not raised).\n",
"    \"\"\"\n",
"    try:\n",
"        weight_df = (\n",
"            acquisition.Environment.SubjectWeight\n",
"            & key\n",
"            & f\"chunk_start >= '{period_start}'\"\n",
"            & f\"chunk_start <= '{period_end}'\"\n",
"        ).fetch(format=\"frame\")\n",
"    except Exception as e:\n",
"        # Best-effort loader: report the failure and fall back to an empty frame.\n",
"        print(\n",
"            f\"Error loading weight data for {key} from {period_start} to {period_end}: {e}\"\n",
"        )\n",
"        return pd.DataFrame()\n",
"\n",
"    # Check the type before touching `.empty`; in the reverse order the isinstance\n",
"    # guard could never take effect.\n",
"    if isinstance(weight_df, pd.DataFrame) and not weight_df.empty:\n",
"        return weight_df\n",
"    return pd.DataFrame()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Nested mapping: experiment name -> period name -> weight DataFrame\n",
"weight_data_dict = {}\n",
"\n",
"for exp in experiments:\n",
"    key = {\"experiment_name\": exp[\"name\"]}\n",
"    exp_weights = weight_data_dict[exp[\"name\"]] = {}\n",
"\n",
"    # Period name -> (start, end) bounds for this experiment\n",
"    periods = {\n",
"        \"presocial\": (exp[\"presocial_start\"], exp[\"presocial_end\"]),\n",
"        \"social\": (exp[\"social_start\"], exp[\"social_end\"]),\n",
"        \"postsocial\": (exp[\"postsocial_start\"], exp[\"postsocial_end\"]),\n",
"    }\n",
"\n",
"    for period_name, (period_start, period_end) in periods.items():\n",
"        # Bounds given as \"%Y-%m-%d %H:%M:%S\" strings are parsed (rejecting malformed\n",
"        # values); anything else is passed through unchanged.\n",
"        bounds_are_text = isinstance(period_start, str)\n",
"        period_start_dt = (\n",
"            datetime.strptime(period_start, \"%Y-%m-%d %H:%M:%S\")\n",
"            if bounds_are_text\n",
"            else period_start\n",
"        )\n",
"        period_end_dt = (\n",
"            datetime.strptime(period_end, \"%Y-%m-%d %H:%M:%S\")\n",
"            if bounds_are_text\n",
"            else period_end\n",
"        )\n",
"\n",
"        weight_df = load_weight_data(key, str(period_start_dt), str(period_end_dt))\n",
"\n",
"        # Tag non-empty results with their experiment and period\n",
"        if isinstance(weight_df, pd.DataFrame) and not weight_df.empty:\n",
"            if \"experiment_name\" not in weight_df.columns:\n",
"                weight_df.insert(0, \"experiment_name\", exp[\"name\"])\n",
"            weight_df[\"period\"] = period_name\n",
"\n",
"        exp_weights[period_name] = weight_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the weight data (all experiments, all three periods) to parquet files\n",
"save_all_experiment_data(\n",
"    data_dict=weight_data_dict,\n",
"    data_type=\"weight\",\n",
"    experiments=experiments,\n",
"    periods=[\"presocial\", \"social\", \"postsocial\"],\n",
"    data_dir=data_dir,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload one experiment's social-period weight data from parquet (swap in any experiment name)\n",
"combined_social_weight_df = load_data_from_parquet(\n",
"    data_dir=data_dir,\n",
"    data_type=\"weight\",\n",
"    experiment_name=\"social0.2-aeon3\",\n",
"    period=\"social\",\n",
")\n",
"display(combined_social_weight_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sleep bouts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def excise_swaps(pos_df: pd.DataFrame, max_speed: float) -> pd.DataFrame:\n",
"    \"\"\"Excises swaps in the position data.\n",
"\n",
"    Computes the instantaneous speed between consecutive samples and treats every\n",
"    sample faster than `max_speed` as part of an identity swap. Each contiguous run of\n",
"    such samples is removed in full, so the removed set is exactly the over-threshold\n",
"    samples. When at least one swap sample is found, rows with undefined speed (e.g.\n",
"    the first sample) are removed as well.\n",
"\n",
"    NOTE(review): rows with undefined speed are only dropped when a swap is present;\n",
"    confirm whether they should always (or never) be dropped.\n",
"\n",
"    `pos_df` is modified in place (an `inst_speed` column is added and rows may be\n",
"    dropped) and is also returned.\n",
"\n",
"    Args:\n",
"        pos_df (pd.DataFrame): DataFrame containing position data of a single subject,\n",
"            with a datetime index and \"x\"/\"y\" columns.\n",
"        max_speed (float): Maximum speed (px/s) threshold over which we assume a swap.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: DataFrame with swaps excised.\n",
"    \"\"\"\n",
"    # Time steps in seconds as a plain array, so the division below is positional\n",
"    # (no index alignment, which would misbehave on duplicate timestamps).\n",
"    dt = pos_df.index.to_series().diff().dt.total_seconds().to_numpy()\n",
"    dx = pos_df[\"x\"].diff()\n",
"    dy = pos_df[\"y\"].diff()\n",
"    pos_df[\"inst_speed\"] = np.sqrt(dx**2 + dy**2) / dt\n",
"\n",
"    # Identify jumps. Every run of over-threshold samples is excised from its start to\n",
"    # its end, so masking the over-threshold samples positionally is equivalent; it also\n",
"    # works with duplicate/unsorted timestamps and with empty input.\n",
"    jumps = (pos_df[\"inst_speed\"] > max_speed).to_numpy()\n",
"    if jumps.any():\n",
"        pos_df[\"inst_speed\"] = pos_df[\"inst_speed\"].mask(jumps)\n",
"        pos_df.dropna(subset=[\"inst_speed\"], inplace=True)\n",
"\n",
"    return pos_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Given pos_df and animal name, return all sleep bouts in df within the pos_df time period\n",
"\n",
"\n",
"def sleep_bouts(\n",
"    pos_df: pd.DataFrame,\n",
"    subject: str,\n",
"    move_thresh: float = 4 * cm2px,  # cm -> px\n",
"    max_speed: float = 100 * cm2px,  # cm/s -> px/s\n",
") -> pd.DataFrame:\n",
"    \"\"\"Returns sleep bouts for a given animal within the specified position data time period.\n",
"\n",
"    The animal's data is split into consecutive 1-minute windows starting at its first\n",
"    timestamp. A window with at least 100 samples is a sleep window when the diagonal of\n",
"    the bounding box of its x/y positions is below `move_thresh`. Touching sleep windows\n",
"    are merged into bouts, and only bouts lasting at least 2 minutes are returned.\n",
"\n",
"    Args:\n",
"        pos_df (pd.DataFrame): DataFrame containing position data (datetime index and\n",
"            \"identity_name\", \"x\", \"y\", \"period\" columns).\n",
"        subject (str): Name of the animal to filter by.\n",
"        move_thresh (float): Movement (in px) threshold to define sleep bouts.\n",
"        max_speed (float): Maximum speed threshold for excising swaps. Currently unused,\n",
"            because swap excision is disabled for sleep detection.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: Sleep bouts for the specified animal with columns subject, start,\n",
"        end, duration, period (empty, with the same columns, if there are none).\n",
"    \"\"\"\n",
"    bout_columns = [\"subject\", \"start\", \"end\", \"duration\", \"period\"]\n",
"\n",
"    animal_data = pos_df[pos_df[\"identity_name\"] == subject]\n",
"    if animal_data.empty or not isinstance(animal_data, pd.DataFrame):\n",
"        print(f\"No position data found for {subject}\")\n",
"        return pd.DataFrame(columns=bout_columns)\n",
"\n",
"    period = animal_data[\"period\"].iloc[0]\n",
"    # Sort once (stable, so ties keep their order) so each window can be located by\n",
"    # binary search instead of re-scanning the whole frame for every window.\n",
"    animal_data = animal_data.sort_index(kind=\"mergesort\")\n",
"    timestamps = animal_data.index\n",
"    x_pos = animal_data[\"x\"]\n",
"    y_pos = animal_data[\"y\"]\n",
"    period_col = animal_data[\"period\"]\n",
"\n",
"    # Create 1-minute windows spanning the animal's data\n",
"    sleep_win = pd.Timedelta(\"1m\")\n",
"    window_starts = pd.date_range(\n",
"        start=timestamps.min(), end=timestamps.max(), freq=sleep_win\n",
"    )\n",
"\n",
"    sleep_windows = []\n",
"    pbar = tqdm(window_starts, desc=f\"Processing sleep bouts for {subject} in {period}\")\n",
"    for win_start in pbar:\n",
"        win_end = win_start + sleep_win\n",
"        # Positional bounds of the rows with win_start <= t < win_end\n",
"        lo = timestamps.searchsorted(win_start, side=\"left\")\n",
"        hi = timestamps.searchsorted(win_end, side=\"left\")\n",
"        if hi - lo < 100:  # skip windows with too little data\n",
"            continue\n",
"\n",
"        # Displacement: diagonal of the bounding box of all positions in the window\n",
"        win_x = x_pos.iloc[lo:hi]\n",
"        win_y = y_pos.iloc[lo:hi]\n",
"        dx = win_x.max() - win_x.min()\n",
"        dy = win_y.max() - win_y.min()\n",
"        displacement = np.sqrt(dx**2 + dy**2)\n",
"\n",
"        # If displacement is less than threshold, consider it a sleep window\n",
"        if displacement < move_thresh:\n",
"            sleep_windows.append(\n",
"                {\n",
"                    \"subject\": subject,\n",
"                    \"start\": win_start,\n",
"                    \"end\": win_end,\n",
"                    \"duration\": sleep_win,\n",
"                    \"period\": period_col.iloc[lo],\n",
"                }\n",
"            )\n",
"\n",
"    # Merge consecutive (touching) sleep windows into continuous bouts\n",
"    bouts = []\n",
"    for window in sleep_windows:\n",
"        if bouts and window[\"start\"] == bouts[-1][\"end\"]:  # continue bout\n",
"            bouts[-1][\"end\"] = window[\"end\"]\n",
"            bouts[-1][\"duration\"] = bouts[-1][\"end\"] - bouts[-1][\"start\"]\n",
"        else:  # start a new bout\n",
"            bouts.append(dict(window))\n",
"\n",
"    # Keep only bouts of at least the minimum bout time\n",
"    min_bout_time = pd.Timedelta(\"2m\")\n",
"    long_bouts = [bout for bout in bouts if bout[\"duration\"] >= min_bout_time]\n",
"    return pd.DataFrame(long_bouts, columns=bout_columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Save sleep bouts to parquet files for all experiments and periods.\"\"\"\n",
"\n",
"# For each experiment, for each period: load position data, get sleep bouts, save to parquet\n",
"periods = [\"presocial\", \"social\", \"postsocial\"]\n",
"bout_columns = [\"subject\", \"start\", \"end\", \"duration\", \"period\"]\n",
"\n",
"pbar_exp = tqdm(experiments, desc=\"Processing experiments\")\n",
"for exp in pbar_exp:\n",
"    sleep_bouts_data_dict = {exp[\"name\"]: {}}\n",
"    pbar_period = tqdm(periods, desc=\"Processing periods\", leave=False)\n",
"    for period_name in pbar_period:\n",
"        print(f\" Loading {period_name} period...\")\n",
"\n",
"        # load pos data for this period\n",
"        pos_df = load_data_from_parquet(\n",
"            experiment_name=exp[\"name\"],\n",
"            period=period_name,\n",
"            data_type=\"position\",\n",
"            data_dir=data_dir,\n",
"            set_time_index=True,\n",
"        )\n",
"        has_position = isinstance(pos_df, pd.DataFrame) and not pos_df.empty\n",
"        if not has_position:\n",
"            print(f\" {period_name}: No position data available\")\n",
"\n",
"        # Get sleep bouts for each subject. A comprehension is used so no global named\n",
"        # `subject` is bound (it would shadow the `subject` module from aeon.dj_pipeline).\n",
"        subject_names = pos_df[\"identity_name\"].unique() if has_position else []\n",
"        subject_bouts = [sleep_bouts(pos_df, name) for name in subject_names]\n",
"        subject_bouts = [\n",
"            bouts\n",
"            for bouts in subject_bouts\n",
"            if isinstance(bouts, pd.DataFrame) and not bouts.empty\n",
"        ]\n",
"        # Concatenate once, without an empty placeholder frame, so dtypes are preserved\n",
"        sleep_bouts_df = (\n",
"            pd.concat(subject_bouts, ignore_index=True)\n",
"            if subject_bouts\n",
"            else pd.DataFrame(columns=bout_columns)\n",
"        )\n",
"\n",
"        # save data dict\n",
"        sleep_bouts_data_dict[exp[\"name\"]][period_name] = sleep_bouts_df\n",
"        save_all_experiment_data(\n",
"            experiments=[exp],\n",
"            periods=[period_name],\n",
"            data_dict=sleep_bouts_data_dict,\n",
"            data_type=\"sleep\",\n",
"            data_dir=data_dir,\n",
"        )\n",
"        print(f\" Saved sleep bouts for {exp['name']} during {period_name} period.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example usage: reload the presocial sleep bouts of one experiment from parquet\n",
"sleep_df = load_data_from_parquet(\n",
"    data_dir=data_dir,\n",
"    data_type=\"sleep\",\n",
"    experiment_name=\"social0.2-aeon3\",\n",
"    period=\"presocial\",\n",
"    set_time_index=True,\n",
")\n",
"display(sleep_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Drink bouts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def drink_bouts(\n",
"    pos_df: pd.DataFrame,\n",
"    subject: str,\n",
"    spout_loc: tuple[float, float],  # x,y spout location in px\n",
"    start_radius: float = 4 * 5.2,  # must be within X cm of spout, in px\n",
"    move_thresh: float = 2.5 * 5.2,  # during bout must move less than X cm, in px\n",
"    min_dur: float = 6,  # min duration of bout in seconds\n",
"    max_dur: float = 90,  # max duration of bout in seconds\n",
") -> pd.DataFrame:  # cols: subject, start, end, duration, period\n",
"    \"\"\"Returns drink bouts for a given animal within the specified position data time period.\n",
"\n",
"    x/y positions are averaged into 100 ms bins and linearly interpolated. A candidate\n",
"    bout starts at any sample within `start_radius` of the spout and extends while the\n",
"    animal stays within `move_thresh` of its starting position and no more than\n",
"    `max_dur` seconds have elapsed. Candidates lasting strictly between `min_dur` and\n",
"    `max_dur` seconds are kept; scanning resumes after the end of each candidate.\n",
"\n",
"    Args:\n",
"        pos_df (pd.DataFrame): Position data with a datetime index and \"identity_name\",\n",
"            \"x\", \"y\" and \"period\" columns.\n",
"        subject (str): Name of the animal to filter by.\n",
"        spout_loc (tuple[float, float]): (x, y) spout location in px.\n",
"        start_radius (float): Max distance (px) from the spout at which a bout can start.\n",
"        move_thresh (float): Max displacement (px) from the start position during a bout.\n",
"        min_dur (float): Bouts must last strictly longer than this (seconds).\n",
"        max_dur (float): Bouts must last strictly shorter than this (seconds).\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: Drink bouts with columns subject, start, end, duration, period.\n",
"    \"\"\"\n",
"    bout_columns = [\"subject\", \"start\", \"end\", \"duration\", \"period\"]\n",
"\n",
"    animal_data = pos_df[pos_df[\"identity_name\"] == subject]\n",
"    if animal_data.empty or not isinstance(animal_data, pd.DataFrame):\n",
"        print(f\"No position data found for {subject}\")\n",
"        return pd.DataFrame(columns=bout_columns)\n",
"    period = pos_df[\"period\"].iloc[0]\n",
"\n",
"    # Smooth position data to 100ms intervals. Only x/y are used, so NaNs in other\n",
"    # numeric columns no longer cause valid position samples to be dropped.\n",
"    xy = animal_data[[\"x\", \"y\"]].resample(\"100ms\").mean().interpolate().dropna()\n",
"    if xy.empty:\n",
"        return pd.DataFrame(columns=bout_columns)\n",
"\n",
"    times = xy.index\n",
"    xs = xy[\"x\"].to_numpy()\n",
"    ys = xy[\"y\"].to_numpy()\n",
"    spout_x, spout_y = spout_loc\n",
"    # Potential bout starts: samples within start_radius of the spout\n",
"    near_spout = np.sqrt((xs - spout_x) ** 2 + (ys - spout_y) ** 2) <= start_radius\n",
"\n",
"    bouts = []\n",
"    n_samples = len(xy)\n",
"    i = 0\n",
"    with tqdm(\n",
"        total=n_samples, desc=f\"Processing drink bouts for {subject} in {period}\"\n",
"    ) as pbar:\n",
"        while i < n_samples:\n",
"            if not near_spout[i]:\n",
"                next_i = i + 1\n",
"            else:\n",
"                bout_start_time = times[i]\n",
"                start_x, start_y = xs[i], ys[i]\n",
"\n",
"                # Extend the bout while displacement from the start position stays\n",
"                # within move_thresh and at most max_dur seconds have elapsed.\n",
"                # (Proximity to the spout is only required at the start sample.)\n",
"                j = i\n",
"                max_displacement = 0\n",
"                while j < n_samples:\n",
"                    elapsed_time = (times[j] - bout_start_time).total_seconds()\n",
"                    displacement = np.sqrt((xs[j] - start_x) ** 2 + (ys[j] - start_y) ** 2)\n",
"                    max_displacement = max(max_displacement, displacement)\n",
"                    if max_displacement > move_thresh or elapsed_time > max_dur:\n",
"                        break\n",
"                    j += 1\n",
"\n",
"                # The bout ends at the last sample that satisfied both criteria\n",
"                bout_end_time = times[j - 1] if j > i else bout_start_time\n",
"                bout_duration = (bout_end_time - bout_start_time).total_seconds()\n",
"                if min_dur < bout_duration < max_dur:\n",
"                    bouts.append(\n",
"                        {\n",
"                            \"subject\": subject,\n",
"                            \"start\": bout_start_time,\n",
"                            \"end\": bout_end_time,\n",
"                            \"duration\": pd.Timedelta(seconds=bout_duration),\n",
"                            \"period\": period,\n",
"                        }\n",
"                    )\n",
"                # Move to next potential bout (skip past current bout end)\n",
"                next_i = max(j, i + 1)\n",
"\n",
"            # Advance the progress bar by the number of samples actually consumed\n",
"            pbar.update(next_i - i)\n",
"            i = next_i\n",
"\n",
"    return pd.DataFrame(bouts, columns=bout_columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Save drink bouts to parquet files for all experiments and periods.\"\"\"\n",
"\n",
"# For each experiment, for each period: load position data, get drink bouts, save to parquet.\n",
"# Period names are listed explicitly instead of being read from a `periods` variable\n",
"# defined by some other cell.\n",
"drink_periods = [\"presocial\", \"social\", \"postsocial\"]\n",
"bout_columns = [\"subject\", \"start\", \"end\", \"duration\", \"period\"]\n",
"\n",
"pbar_exp = tqdm(experiments, desc=\"Processing experiments\")\n",
"for exp in pbar_exp:\n",
"    drink_bouts_data_dict = {exp[\"name\"]: {}}\n",
"    # Spout (x, y) location in px, chosen per arena.\n",
"    # NOTE(review): hard-coded; confirm these coordinates for every aeon3/aeon4 experiment.\n",
"    spout_loc = (1280, 500) if \"aeon3\" in exp[\"name\"] else (1245, 535)\n",
"    pbar_period = tqdm(drink_periods, desc=\"Processing periods\", leave=False)\n",
"    for period_name in pbar_period:\n",
"        print(f\" Loading {period_name} period...\")\n",
"\n",
"        # load pos data for this period\n",
"        pos_df = load_data_from_parquet(\n",
"            experiment_name=exp[\"name\"],\n",
"            period=period_name,\n",
"            data_type=\"position\",\n",
"            data_dir=data_dir,\n",
"            set_time_index=True,\n",
"        )\n",
"        has_position = isinstance(pos_df, pd.DataFrame) and not pos_df.empty\n",
"        if not has_position:\n",
"            print(f\" {period_name}: No position data available\")\n",
"\n",
"        # Get drink bouts for each subject. A comprehension is used so no global named\n",
"        # `subject` is bound (it would shadow the `subject` module from aeon.dj_pipeline).\n",
"        subject_names = pos_df[\"identity_name\"].unique() if has_position else []\n",
"        subject_bouts = [drink_bouts(pos_df, name, spout_loc) for name in subject_names]\n",
"        subject_bouts = [\n",
"            bouts\n",
"            for bouts in subject_bouts\n",
"            if isinstance(bouts, pd.DataFrame) and not bouts.empty\n",
"        ]\n",
"        # Concatenate once, without an empty placeholder frame, so dtypes are preserved\n",
"        drink_bouts_df = (\n",
"            pd.concat(subject_bouts, ignore_index=True)\n",
"            if subject_bouts\n",
"            else pd.DataFrame(columns=bout_columns)\n",
"        )\n",
"\n",
"        # save data dict\n",
"        drink_bouts_data_dict[exp[\"name\"]][period_name] = drink_bouts_df\n",
"        save_all_experiment_data(\n",
"            experiments=[exp],\n",
"            periods=[period_name],\n",
"            data_dict=drink_bouts_data_dict,\n",
"            data_type=\"drink\",\n",
"            data_dir=data_dir,\n",
"        )\n",
"        print(f\" Saved drink bouts for {exp['name']} during {period_name} period.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the drink bouts of the last experiment/period processed by the loop above\n",
"drink_bouts_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explore bouts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Given pos_df, animal name, nest xy, return all explore bouts in df\n",
"\n",
"nest_center = np.array((1215, 530))\n",
"cm2px = 5.2\n",
"nest_radius = 14 * cm2px  # 14 cm, in px\n",
"\n",
"\n",
"def explore_bouts(\n",
"    pos_df: pd.DataFrame,\n",
"    subject: str,\n",
"    nest_center: np.ndarray,\n",
"    nest_radius: float = 14 * 5.2,  # 14 cm, in px\n",
"    max_speed: float = 100 * 5.2,  # 100 cm/s, in px/s\n",
") -> pd.DataFrame:\n",
"    \"\"\"Returns exploration bouts for a given animal within the specified position data time period.\n",
"\n",
"    The animal's data is split into consecutive 1-minute windows starting at its first\n",
"    timestamp. For each window with at least 100 samples, identity swaps are excised\n",
"    with `excise_swaps`; the window counts as exploration when more than half of the\n",
"    remaining samples lie farther than `nest_radius` from `nest_center`. Touching\n",
"    exploration windows are merged into bouts.\n",
"\n",
"    Args:\n",
"        pos_df (pd.DataFrame): DataFrame containing position data (datetime index and\n",
"            \"identity_name\", \"x\", \"y\", \"period\" columns).\n",
"        subject (str): Name of the animal to filter by.\n",
"        nest_center (np.ndarray): Coordinates of the nest center.\n",
"        nest_radius (float): Radius of the nest area (default: 14 cm in px).\n",
"        max_speed (float): Maximum speed threshold for excising swaps (default: 100 cm/s in px/s).\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: Exploration bouts for the specified animal with columns subject,\n",
"        start, end, duration, period (empty, with the same columns, if there are none).\n",
"    \"\"\"\n",
"    bout_columns = [\"subject\", \"start\", \"end\", \"duration\", \"period\"]\n",
"\n",
"    animal_data = pos_df[pos_df[\"identity_name\"] == subject]\n",
"    if animal_data.empty or not isinstance(animal_data, pd.DataFrame):\n",
"        print(f\"No position data found for {subject}\")\n",
"        return pd.DataFrame(columns=bout_columns)\n",
"\n",
"    period = animal_data[\"period\"].iloc[0]\n",
"    # Sort once (stable, so ties keep their order) so each window can be located by\n",
"    # binary search instead of re-scanning the whole frame for every window.\n",
"    animal_data = animal_data.sort_index(kind=\"mergesort\")\n",
"    timestamps = animal_data.index\n",
"\n",
"    # Create 1-minute windows spanning the animal's data\n",
"    explore_win = pd.Timedelta(\"1m\")\n",
"    window_starts = pd.date_range(\n",
"        start=timestamps.min(), end=timestamps.max(), freq=explore_win\n",
"    )\n",
"\n",
"    explore_windows = []\n",
"    pbar = tqdm(window_starts, desc=f\"Processing explore bouts for {subject} in {period}\")\n",
"    for win_start in pbar:\n",
"        win_end = win_start + explore_win\n",
"        # Positional bounds of the rows with win_start <= t < win_end\n",
"        lo = timestamps.searchsorted(win_start, side=\"left\")\n",
"        hi = timestamps.searchsorted(win_end, side=\"left\")\n",
"        if hi - lo < 100:  # skip windows with too little data\n",
"            continue\n",
"\n",
"        # Excise id swaps (based on pos / speed jumps); excise_swaps modifies its input,\n",
"        # so hand it a copy of the window.\n",
"        win_data = excise_swaps(animal_data.iloc[lo:hi].copy(), max_speed)\n",
"        if win_data.empty:  # everything excised: no fraction to compute\n",
"            continue\n",
"\n",
"        # If majority of time in a window is outside nest, consider it an explore window\n",
"        dx = win_data[\"x\"] - nest_center[0]\n",
"        dy = win_data[\"y\"] - nest_center[1]\n",
"        distance_from_nest = np.sqrt(dx**2 + dy**2)\n",
"        frac_out_nest = (distance_from_nest > nest_radius).sum() / len(win_data)\n",
"        if frac_out_nest > 0.5:\n",
"            explore_windows.append(\n",
"                {\n",
"                    \"subject\": subject,\n",
"                    \"start\": win_start,\n",
"                    \"end\": win_end,\n",
"                    \"duration\": explore_win,\n",
"                    \"period\": win_data[\"period\"].iloc[0],\n",
"                }\n",
"            )\n",
"\n",
"    # Merge consecutive (touching) explore windows into continuous bouts\n",
"    bouts = []\n",
"    for window in explore_windows:\n",
"        if bouts and window[\"start\"] == bouts[-1][\"end\"]:  # continue bout\n",
"            bouts[-1][\"end\"] = window[\"end\"]\n",
"            bouts[-1][\"duration\"] = bouts[-1][\"end\"] - bouts[-1][\"start\"]\n",
"        else:  # start a new bout\n",
"            bouts.append(dict(window))\n",
"\n",
"    return pd.DataFrame(bouts, columns=bout_columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Save explore bouts to parquet files for each experiment and period\"\"\"\n",
"\n",
"# For each experiment, for each period: load position data, get explore bouts, save to parquet\n",
"periods = [\"presocial\", \"social\", \"postsocial\"]\n",
"bout_columns = [\"subject\", \"start\", \"end\", \"duration\", \"period\"]\n",
"\n",
"pbar_exp = tqdm(experiments, desc=\"Processing experiments\")\n",
"for exp in pbar_exp:\n",
"    key = {\"experiment_name\": exp[\"name\"]}\n",
"\n",
"    # Nest center for this experiment: mean of the \"NestRegion\" polygon vertices taken\n",
"    # from the active-region configuration of the experiment's epochs\n",
"    epoch_query = acquisition.Epoch & (acquisition.Chunk & key).proj(\"epoch_start\")\n",
"    active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query\n",
"    roi_locs = dict(\n",
"        zip(*active_region_query.fetch(\"region_name\", \"region_data\"), strict=True)\n",
"    )\n",
"    points = roi_locs[\"NestRegion\"][\"ArrayOfPoint\"]\n",
"    vertices = np.array([[float(point[\"X\"]), float(point[\"Y\"])] for point in points])\n",
"    nest_center = np.mean(vertices, axis=0)\n",
"\n",
"    explore_bouts_data_dict = {exp[\"name\"]: {}}\n",
"    pbar_period = tqdm(periods, desc=\"Processing periods\", leave=False)\n",
"    for period_name in pbar_period:\n",
"        print(f\" Loading {period_name} period...\")\n",
"\n",
"        # load pos data for this period\n",
"        pos_df = load_data_from_parquet(\n",
"            experiment_name=exp[\"name\"],\n",
"            period=period_name,\n",
"            data_type=\"position\",\n",
"            data_dir=data_dir,\n",
"            set_time_index=True,\n",
"        )\n",
"        has_position = isinstance(pos_df, pd.DataFrame) and not pos_df.empty\n",
"        if not has_position:\n",
"            print(f\" {period_name}: No position data available\")\n",
"\n",
"        # Get explore bouts for each subject. A comprehension is used so no global named\n",
"        # `subject` is bound (it would shadow the `subject` module from aeon.dj_pipeline).\n",
"        subject_names = pos_df[\"identity_name\"].unique() if has_position else []\n",
"        subject_bouts = [explore_bouts(pos_df, name, nest_center) for name in subject_names]\n",
"        subject_bouts = [\n",
"            bouts\n",
"            for bouts in subject_bouts\n",
"            if isinstance(bouts, pd.DataFrame) and not bouts.empty\n",
"        ]\n",
"        # Concatenate once, without an empty placeholder frame, so dtypes are preserved\n",
"        explore_bouts_df = (\n",
"            pd.concat(subject_bouts, ignore_index=True)\n",
"            if subject_bouts\n",
"            else pd.DataFrame(columns=bout_columns)\n",
"        )\n",
"\n",
"        # save data dict\n",
"        explore_bouts_data_dict[exp[\"name\"]][period_name] = explore_bouts_df\n",
"        save_all_experiment_data(\n",
"            experiments=[exp],\n",
"            periods=[period_name],\n",
"            data_dict=explore_bouts_data_dict,\n",
"            data_type=\"explore\",\n",
"            data_dir=data_dir,\n",
"        )\n",
"        print(f\" Saved explore bouts for {exp['name']} during {period_name} period.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example usage: reload the presocial explore bouts of one experiment from parquet\n",
"explore_df = load_data_from_parquet(\n",
"    data_dir=data_dir,\n",
"    data_type=\"explore\",\n",
"    experiment_name=\"social0.2-aeon3\",\n",
"    period=\"presocial\",\n",
"    set_time_index=True,\n",
")\n",
"display(explore_df)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}