{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(target-dj-compute-behav-bouts)=\n",
"# DataJoint pipeline: Computing behavioural bouts\n",
"\n",
":::{important}\n",
"This guide assumes you have a [DataJoint pipeline deployed](target-dj-pipeline-deployment) with [data already ingested](target-dj-data-ingestion-processing).\n",
":::\n",
"\n",
"Using position data from the [Aeon DataJoint pipeline](target-aeon-dj-pipeline), this guide walks through computing **sleep**, **drink**, and **explore** bouts for each subject in a multi-animal setup.\n",
"\n",
"You can also run this notebook online at [`works.datajoint.com`](https://works.datajoint.com/) using the following credentials:\n",
" - Username: aeondemo\n",
" - Password: aeon_djworks \n",
"\n",
"To access it, go to the Notebook tab at the top and in the File Browser on the left, navigate to `ucl-swc_aeon > docs > examples`, where this notebook `dj_compute_bouts.ipynb` is located.\n",
"\n",
":::{note}\n",
"The examples here use the _social_ period of the [social0.2-aeon4](target-full-datasets) dataset.\n",
"Since the social period spans 2 weeks, we limit retrieval to the first 24 hours to keep the examples concise.\n",
"\n",
"If you are using a different dataset, be sure to replace the experiment name and parameters in the code below accordingly.\n",
":::"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import libraries and define variables and helper functions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm.notebook import tqdm\n",
"\n",
"from aeon.dj_pipeline import acquisition, streams, tracking"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def remove_swaps(pos_df: pd.DataFrame, max_speed: float) -> pd.DataFrame:\n",
" \"\"\"Detect and remove swaps in the position data.\n",
"\n",
" Args:\n",
" pos_df (pd.DataFrame): DataFrame containing position data of a single subject.\n",
" max_speed (float): Maximum speed (px/s) threshold over which we assume a swap.\n",
"\n",
" Returns:\n",
" pd.DataFrame: DataFrame with swaps removed.\n",
" \"\"\"\n",
" dt = pos_df.index.diff().total_seconds()\n",
" dx = pos_df[\"x\"].diff()\n",
" dy = pos_df[\"y\"].diff()\n",
" pos_df[\"inst_speed\"] = np.sqrt(dx**2 + dy**2) / dt\n",
"\n",
" # Identify jumps\n",
" jumps = (pos_df[\"inst_speed\"] > max_speed)\n",
" shift_down = jumps.shift(1)\n",
" shift_down.iloc[0] = False\n",
" shift_up = jumps.shift(-1)\n",
" shift_up.iloc[len(jumps) - 1] = False\n",
" jump_starts = jumps & ~shift_down\n",
" jump_ends = jumps & ~shift_up\n",
" jump_start_indices = np.where(jump_starts)[0]\n",
" jump_end_indices = np.where(jump_ends)[0]\n",
"\n",
" if np.any(jumps):\n",
" # Ensure the lengths match\n",
" if len(jump_start_indices) > len(jump_end_indices): # jump-in-progress at start\n",
" jump_end_indices = np.append(jump_end_indices, len(pos_df) - 1)\n",
" elif len(jump_start_indices) < len(jump_end_indices): # jump-in-progress at end\n",
" jump_start_indices = np.insert(jump_start_indices, 0, 0)\n",
" # Remove jumps by setting speed to nan in jump regions and dropping nans\n",
" for start, end in zip(jump_start_indices, jump_end_indices, strict=True):\n",
" pos_df.loc[pos_df.index[start]:pos_df.index[end], \"inst_speed\"] = np.nan\n",
" pos_df.dropna(subset=[\"inst_speed\"], inplace=True)\n",
"\n",
" return pos_df"
]
},
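{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of `remove_swaps` (a minimal sketch on synthetic data, not part of the pipeline), a single-sample jump of several hundred pixels yields an instantaneous speed far above `max_speed`, so the jump region is excised:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: a synthetic trajectory with one artificial jump (e.g. an id swap)\n",
"times = pd.date_range(\"2024-02-09 17:00:00\", periods=6, freq=\"100ms\")\n",
"demo_df = pd.DataFrame(\n",
"    {\"x\": [10.0, 11.0, 900.0, 12.0, 13.0, 14.0], \"y\": [5.0] * 6}, index=times\n",
")\n",
"# max_speed of 100 cm/s at roughly 5.2 px/cm; the spike to x=900 px far exceeds it.\n",
"# Both samples in the jump region are dropped (the first row also drops, since\n",
"# its speed is undefined).\n",
"remove_swaps(demo_df.copy(), max_speed=100 * 5.2)"
]
},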
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"cm2px = 5.2 # 1 cm = 5.2 px roughly for top camera\n",
"nest_center = np.array((1215, 530))\n",
"nest_radius = 14 * cm2px # 14 cm, in px\n",
"exp = {\n",
" \"name\": \"social0.2-aeon4\",\n",
" \"presocial_start\": \"2024-01-31 11:00:00\",\n",
" \"presocial_end\": \"2024-02-08 15:00:00\",\n",
" \"social_start\": \"2024-02-09 17:00:00\",\n",
" \"social_end\": \"2024-02-23 12:00:00\",\n",
" \"postsocial_start\": \"2024-02-25 18:00:00\",\n",
" \"postsocial_end\": \"2024-03-02 13:00:00\",\n",
"}\n",
"key = {\"experiment_name\": exp[\"name\"]}\n",
"# Define periods\n",
"periods = {\n",
" \"presocial\": (exp[\"presocial_start\"], exp[\"presocial_end\"]),\n",
" \"social\": (exp[\"social_start\"], exp[\"social_end\"]),\n",
" \"postsocial\": (exp[\"postsocial_start\"], exp[\"postsocial_end\"]),\n",
"}\n",
"# Select the social period and limit to first 3 days for brevity\n",
"period_name = \"social\"\n",
"start = periods[period_name][0]\n",
"start_dt = datetime.strptime(start, \"%Y-%m-%d %H:%M:%S\")\n",
"end_dt = start_dt + pd.Timedelta(days=1)"
]
},
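{
"cell_type": "markdown",
"metadata": {},
"source": [
"The nest constants above can be used to flag position samples that fall inside the nest (e.g. to separate in-nest rest from activity elsewhere in the arena). A minimal sketch, where the `in_nest` helper is illustrative rather than part of the pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def in_nest(x: np.ndarray, y: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Illustrative helper: True where a point lies within the nest circle.\"\"\"\n",
"    return (x - nest_center[0]) ** 2 + (y - nest_center[1]) ** 2 <= nest_radius**2\n",
"\n",
"\n",
"# The nest centre is inside; a point 20 cm to its right is not\n",
"in_nest(np.array([1215, 1215 + 20 * cm2px]), np.array([530, 530]))"
]
},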
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fetch position data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Querying data from 2024-02-09 17:00:00 to 2024-02-10 17:00:00...\n",
" Retrieved 7162981 rows of position data\n"
]
}
],
"source": [
"def load_position_data(\n",
" key: dict[str, str], period_start: str, period_end: str\n",
") -> pd.DataFrame:\n",
" \"\"\"Loads position data (centroid tracking) for a specified time period.\n",
"\n",
" Args:\n",
" key (dict): Key to identify experiment data (e.g., {\"experiment_name\": \"Exp1\"}).\n",
" period_start (str): Start datetime of the time period.\n",
" period_end (str): End datetime of the time period.\n",
"\n",
" Returns:\n",
" pd.DataFrame: DataFrame containing position data for the specified period.\n",
" Returns an empty DataFrame if no data found.\n",
" \"\"\"\n",
" try:\n",
" print(f\" Querying data from {period_start} to {period_end}...\")\n",
"\n",
" # Create chunk restriction for the time period\n",
" chunk_restriction = acquisition.create_chunk_restriction(\n",
" key[\"experiment_name\"], period_start, period_end\n",
" )\n",
"\n",
" # Fetch centroid tracking data for the specified period\n",
" centroid_df = (\n",
" streams.SpinnakerVideoSource * tracking.DenoisedTracking.Subject\n",
" & key\n",
" & {\"spinnaker_video_source_name\": \"CameraTop\"}\n",
" & chunk_restriction\n",
" ).fetch(format=\"frame\")\n",
"\n",
" centroid_df = centroid_df.reset_index()\n",
" centroid_df = centroid_df.rename(\n",
" columns={\n",
" \"subject_name\": \"identity_name\",\n",
" \"timestamps\": \"time\",\n",
" \"subject_likelihood\": \"identity_likelihood\",\n",
" }\n",
" )\n",
" centroid_df = centroid_df.explode(\n",
" [\"time\", \"identity_likelihood\", \"x\", \"y\", \"likelihood\"]\n",
" )\n",
" centroid_df = centroid_df[\n",
" [\n",
" \"time\",\n",
" \"experiment_name\",\n",
" \"identity_name\",\n",
" \"identity_likelihood\",\n",
" \"x\",\n",
" \"y\",\n",
" \"likelihood\",\n",
" ]\n",
" ].set_index(\"time\")\n",
"\n",
" # Clean up the dataframe\n",
" if isinstance(centroid_df, pd.DataFrame) and not centroid_df.empty:\n",
" if \"spinnaker_video_source_name\" in centroid_df.columns:\n",
" centroid_df.drop(columns=[\"spinnaker_video_source_name\"], inplace=True)\n",
" print(f\" Retrieved {len(centroid_df)} rows of position data\")\n",
" else:\n",
" print(\" No data found for the specified period\")\n",
"\n",
" return centroid_df\n",
"\n",
" except Exception as e:\n",
" print(\n",
" f\" Error loading position data for {key['experiment_name']} ({period_start} \"\n",
" f\"to {period_end}): {e}\"\n",
" )\n",
" return pd.DataFrame()\n",
"\n",
"\n",
"# Load position data\n",
"# If this takes too long, consider changing end_dt to an earlier time\n",
"position_df = load_position_data(key, start_dt, end_dt).sort_index()\n",
"\n",
"float_cols = ['identity_likelihood', 'x', 'y', 'likelihood']\n",
"for col in float_cols:\n",
" position_df[col] = pd.to_numeric(position_df[col], errors='coerce')"
]
},
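{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, we can count the retrieved samples per tracked identity:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: number of position samples per identity\n",
"position_df.groupby(\"identity_name\").size()"
]
},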
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sleep bouts\n",
"\n",
"Using position data, we can infer sleep bouts for each subject by identifying sustained periods of low movement:\n",
"\n",
"1. Divide the data into consecutive 1-minute windows.\n",
"2. For each window, compute the maximum displacement between the furthest tracked points.\n",
"3. Label windows as sleep if displacement falls below a threshold (`move_thresh`, e.g. 4 cm converted to pixels).\n",
"4. Merge adjacent sleep windows into continuous sleep bouts.\n",
"5. Exclude bouts shorter than 2 minutes."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def sleep_bouts(\n",
" pos_df: pd.DataFrame,\n",
" subject: str,\n",
" move_thresh: float = 4 * cm2px, # cm -> px\n",
" max_speed: float = 100 * cm2px, # cm/s -> px/s\n",
") -> pd.DataFrame:\n",
" \"\"\"Returns sleep bouts for a given animal within the specified position data time period.\n",
"\n",
" Args:\n",
" pos_df (pd.DataFrame): DataFrame containing position data.\n",
" subject (str): Name of the animal to filter by.\n",
" move_thresh (float): Movement (in px) threshold to define sleep bouts.\n",
" max_speed (float): Maximum speed threshold for excising swaps.\n",
"\n",
" Returns:\n",
" pd.DataFrame: DataFrame containing sleep bouts for the specified animal.\n",
" \"\"\"\n",
" animal_data = pos_df[pos_df[\"identity_name\"] == subject].copy()\n",
" if animal_data.empty or not isinstance(animal_data, pd.DataFrame):\n",
" print(f\"No position data found for {subject}\")\n",
" return pd.DataFrame()\n",
"\n",
" # Set some constants and placeholder `windows_df` which will be combined into `bouts_df`\n",
" sleep_win = pd.Timedelta(\"1m\")\n",
" sleep_windows_df = pd.DataFrame(\n",
" columns=[\"subject\", \"start\", \"end\", \"duration\"]\n",
" )\n",
"\n",
" # Create time windows based on start and end time\n",
" data_start_time = animal_data.index.min()\n",
" data_end_time = animal_data.index.max()\n",
" window_starts = pd.date_range(\n",
" start=data_start_time, end=data_end_time, freq=sleep_win\n",
" )\n",
"\n",
" # Process each time window\n",
" pbar = tqdm(window_starts, desc=f\"Processing sleep bouts for {subject} \")\n",
" for win_start in pbar:\n",
" win_end = win_start + sleep_win\n",
" win_data = animal_data[\n",
" (animal_data.index >= win_start) & (animal_data.index < win_end)\n",
" ].copy()\n",
" if len(win_data) < 100: # skip windows with too little data\n",
" continue\n",
"\n",
" # Excise id swaps (based on pos / speed jumps)\n",
" # win_data = correct_swaps(win_data, max_speed)\n",
"\n",
" # Calculate the displacement - maximum distance between any two points in the window\n",
" dx = win_data[\"x\"].max() - win_data[\"x\"].min()\n",
" dy = win_data[\"y\"].max() - win_data[\"y\"].min()\n",
" displacement = np.sqrt(dx**2 + dy**2)\n",
"\n",
" # If displacement is less than threshold, consider it a sleep bout\n",
" if displacement < move_thresh:\n",
" new_bout = {\n",
" \"subject\": subject,\n",
" \"start\": win_start,\n",
" \"end\": win_end,\n",
" \"duration\": sleep_win,\n",
" }\n",
" sleep_windows_df = pd.concat(\n",
" [sleep_windows_df, pd.DataFrame([new_bout])], ignore_index=True\n",
" )\n",
" # \n",
"\n",
" # Now merge consecutive sleep windows into continuous bouts\n",
" if sleep_windows_df.empty or not isinstance(sleep_windows_df, pd.DataFrame):\n",
" return pd.DataFrame(columns=[\"subject\", \"start\", \"end\", \"duration\"])\n",
" # Initialize the merged bouts dataframe with the first window\n",
" sleep_bouts_df = pd.DataFrame(\n",
" [\n",
" {\n",
" \"subject\": subject,\n",
" \"start\": sleep_windows_df.iloc[0][\"start\"],\n",
" \"end\": sleep_windows_df.iloc[0][\"end\"],\n",
" \"duration\": sleep_windows_df.iloc[0][\"duration\"],\n",
" }\n",
" ]\n",
" )\n",
" # Iterate through remaining windows and merge consecutive ones\n",
" for i in range(1, len(sleep_windows_df)):\n",
" current_window = sleep_windows_df.iloc[i]\n",
" last_bout = sleep_bouts_df.iloc[-1]\n",
"\n",
" if current_window[\"start\"] == last_bout[\"end\"]: # continue bout\n",
" sleep_bouts_df.at[len(sleep_bouts_df) - 1, \"end\"] = current_window[\"end\"]\n",
" sleep_bouts_df.at[len(sleep_bouts_df) - 1, \"duration\"] = (\n",
" sleep_bouts_df.iloc[-1][\"end\"] - sleep_bouts_df.iloc[-1][\"start\"]\n",
" )\n",
" else: # start a new bout\n",
" new_bout = {\n",
" \"subject\": subject,\n",
" \"start\": current_window[\"start\"],\n",
" \"end\": current_window[\"end\"],\n",
" \"duration\": current_window[\"duration\"],\n",
" }\n",
" sleep_bouts_df = pd.concat(\n",
" [sleep_bouts_df, pd.DataFrame([new_bout])], ignore_index=True\n",
" )\n",
" # \n",
"\n",
" # Set min bout time\n",
" min_bout_time = pd.Timedelta(\"2m\")\n",
" sleep_bouts_df = sleep_bouts_df[sleep_bouts_df[\"duration\"] >= min_bout_time]\n",
"\n",
" return sleep_bouts_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compute sleep bouts for each subject\n",
"subjects = position_df[\"identity_name\"].unique()\n",
"sleep_bouts_df = pd.DataFrame(\n",
" columns=[\"subject\", \"start\", \"end\", \"duration\"]\n",
")\n",
"for subject in subjects:\n",
" subject_bouts = sleep_bouts(position_df, subject)\n",
" if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:\n",
" sleep_bouts_df = pd.concat(\n",
" [sleep_bouts_df, subject_bouts], ignore_index=True\n",
" )\n"
]
},
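{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before inspecting the full table of bouts, an optional quick summary: total sleep time per subject over the window."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: total time asleep per subject across all detected bouts\n",
"sleep_bouts_df.groupby(\"subject\")[\"duration\"].sum()"
]
},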
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n", " | subject | \n", "start | \n", "end | \n", "duration | \n", "
---|---|---|---|---|
0 | \n", "BAA-1104048 | \n", "2024-02-09 19:04:10.660 | \n", "2024-02-09 19:06:10.660 | \n", "0 days 00:02:00 | \n", "
1 | \n", "BAA-1104048 | \n", "2024-02-09 20:27:10.660 | \n", "2024-02-09 20:29:10.660 | \n", "0 days 00:02:00 | \n", "
2 | \n", "BAA-1104048 | \n", "2024-02-09 21:10:10.660 | \n", "2024-02-09 21:13:10.660 | \n", "0 days 00:03:00 | \n", "
3 | \n", "BAA-1104048 | \n", "2024-02-09 21:16:10.660 | \n", "2024-02-09 22:23:10.660 | \n", "0 days 01:07:00 | \n", "
4 | \n", "BAA-1104048 | \n", "2024-02-09 22:28:10.660 | \n", "2024-02-09 22:33:10.660 | \n", "0 days 00:05:00 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
128 | \n", "BAA-1104049 | \n", "2024-02-10 15:12:10.660 | \n", "2024-02-10 15:15:10.660 | \n", "0 days 00:03:00 | \n", "
129 | \n", "BAA-1104049 | \n", "2024-02-10 15:16:10.660 | \n", "2024-02-10 15:26:10.660 | \n", "0 days 00:10:00 | \n", "
130 | \n", "BAA-1104049 | \n", "2024-02-10 15:28:10.660 | \n", "2024-02-10 15:39:10.660 | \n", "0 days 00:11:00 | \n", "
131 | \n", "BAA-1104049 | \n", "2024-02-10 16:06:10.660 | \n", "2024-02-10 16:08:10.660 | \n", "0 days 00:02:00 | \n", "
132 | \n", "BAA-1104049 | \n", "2024-02-10 16:55:10.660 | \n", "2024-02-10 17:01:10.660 | \n", "0 days 00:06:00 | \n", "
133 rows × 4 columns
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Drink bouts\n",
"\n",
"Drink bouts for each subject over the same period (durations are on the order of seconds):\n",
"\n",
"| | subject | start | end | duration |\n",
"|---|---|---|---|---|\n",
"| 0 | BAA-1104048 | 2024-02-09 16:53:55.400 | 2024-02-09 16:54:02.300 | 0 days 00:00:06.900000 |\n",
"| 1 | BAA-1104048 | 2024-02-09 17:08:33.300 | 2024-02-09 17:08:44.100 | 0 days 00:00:10.800000 |\n",
"| 2 | BAA-1104048 | 2024-02-09 17:08:44.200 | 2024-02-09 17:08:50.400 | 0 days 00:00:06.200000 |\n",
"| 3 | BAA-1104048 | 2024-02-09 17:15:14.000 | 2024-02-09 17:15:24.700 | 0 days 00:00:10.700000 |\n",
"| 4 | BAA-1104048 | 2024-02-09 17:43:53.900 | 2024-02-09 17:44:03.800 | 0 days 00:00:09.900000 |\n",
"| ... | ... | ... | ... | ... |\n",
"| 213 | BAA-1104049 | 2024-02-10 17:26:44.400 | 2024-02-10 17:27:22.100 | 0 days 00:00:37.700000 |\n",
"| 214 | BAA-1104049 | 2024-02-10 17:34:49.400 | 2024-02-10 17:35:19.800 | 0 days 00:00:30.400000 |\n",
"| 215 | BAA-1104049 | 2024-02-10 17:40:56.100 | 2024-02-10 17:41:38.900 | 0 days 00:00:42.800000 |\n",
"| 216 | BAA-1104049 | 2024-02-10 17:41:51.700 | 2024-02-10 17:42:34.100 | 0 days 00:00:42.400000 |\n",
"| 217 | BAA-1104049 | 2024-02-10 17:42:34.200 | 2024-02-10 17:42:58.800 | 0 days 00:00:24.600000 |\n",
"\n",
"218 rows × 4 columns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explore bouts\n",
"\n",
"Explore bouts for each subject over the same period (durations are on the order of minutes):\n",
"\n",
"| | subject | start | end | duration |\n",
"|---|---|---|---|---|\n",
"| 0 | BAA-1104048 | 2024-02-09 16:48:10.660 | 2024-02-09 16:53:10.660 | 0 days 00:05:00 |\n",
"| 1 | BAA-1104048 | 2024-02-09 16:54:10.660 | 2024-02-09 17:08:10.660 | 0 days 00:14:00 |\n",
"| 2 | BAA-1104048 | 2024-02-09 17:09:10.660 | 2024-02-09 17:27:10.660 | 0 days 00:18:00 |\n",
"| 3 | BAA-1104048 | 2024-02-09 17:29:10.660 | 2024-02-09 17:43:10.660 | 0 days 00:14:00 |\n",
"| 4 | BAA-1104048 | 2024-02-09 17:44:10.660 | 2024-02-09 17:53:10.660 | 0 days 00:09:00 |\n",
"| ... | ... | ... | ... | ... |\n",
"| 131 | BAA-1104049 | 2024-02-10 17:03:10.660 | 2024-02-10 17:09:10.660 | 0 days 00:06:00 |\n",
"| 132 | BAA-1104049 | 2024-02-10 17:17:10.660 | 2024-02-10 17:26:10.660 | 0 days 00:09:00 |\n",
"| 133 | BAA-1104049 | 2024-02-10 17:28:10.660 | 2024-02-10 17:34:10.660 | 0 days 00:06:00 |\n",
"| 134 | BAA-1104049 | 2024-02-10 17:35:10.660 | 2024-02-10 17:41:10.660 | 0 days 00:06:00 |\n",
"| 135 | BAA-1104049 | 2024-02-10 17:43:10.660 | 2024-02-10 18:00:10.660 | 0 days 00:17:00 |\n",
"\n",
"136 rows × 4 columns"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 4
}