Platform paper social experiment analysis: Part 1#
In this example we will work with behavioural data collected from experiments social0.2, social0.3, and social0.4, in which two mice foraged for food in the habitat with three foraging patches whose reward rates changed dynamically over time.
The experiments each consist of three periods:

- “presocial”, in which each mouse was in the habitat alone for 3-4 days;
- “social”, in which both mice were in the habitat together for 2 weeks;
- “postsocial”, in which each mouse was in the habitat alone again for 3-4 days.
The goal of the experiments was to understand how the mice’s behaviour changes as they learn to forage for food in the habitat, and how their behaviour differs between social and solo settings.
The full datasets are available on the Datasets page, but for the purpose of this example we will use the precomputed Platform paper social analysis datasets.
See also
“Extended Data Fig. 7”, in “Extended Data” in the “Supplementary Material” of the platform paper for a detailed description of the experiments.
Below is a brief explanation of how the environment (i.e. patch properties) changed over blocks (60–180 minute periods of time):
Every block lasts for a random duration \(t\):

\[ t \sim \mathrm{Uniform}(60,\,180) \quad \text{In minutes} \]

At the start of each block, sample a row from the predefined matrix \(\lambda_{\mathrm{set}}\):

\[\begin{split} \lambda_{\mathrm{set}} = \begin{pmatrix} 1 & 1 & 1 \\ 5 & 5 & 5 \\ 1 & 3 & 5 \\ 1 & 5 & 3 \\ 3 & 1 & 5 \\ 3 & 5 & 1 \\ 5 & 1 & 3 \\ 5 & 3 & 1 \\ \end{pmatrix} \quad \text{In meters} \end{split}\]

Assign the sampled row to the patch means \(\lambda_{\mathrm{1}}, \lambda_{\mathrm{2}}, \lambda_{\mathrm{3}}\) and apply a constant offset \(c\) to all thresholds:

\[\begin{split} \begin{aligned} \lambda_{\mathrm{1}}, \lambda_{\mathrm{2}}, \lambda_{\mathrm{3}} &\sim \mathrm{Uniform}(\lambda_{\mathrm{set}}) \\ c &= 0.75 \end{aligned} \quad \text{Patch means and offset} \end{split}\]

Sample a value from each of \(P_{\mathrm{1}}, P_{\mathrm{2}}, P_{\mathrm{3}}\) as the initial threshold for the respective patch. Whenever a patch reaches its threshold, resample a new value from its corresponding distribution:

\[\begin{split} \begin{aligned} P_{\mathrm{1}} &= c + \mathrm{Exp}(1/\lambda_{\mathrm{1}}) \\ P_{\mathrm{2}} &= c + \mathrm{Exp}(1/\lambda_{\mathrm{2}}) \\ P_{\mathrm{3}} &= c + \mathrm{Exp}(1/\lambda_{\mathrm{3}}) \end{aligned} \quad \text{Patch distributions} \end{split}\]
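The same procedure can be sketched in a few lines of Python. This is a toy simulation of the block structure for illustration only (the names and random-number generator here are ours, not part of the analysis code below):

import numpy as np

rng = np.random.default_rng()

# Candidate patch-mean rows (in meters), matching the matrix above
lambda_set = np.array(
    [
        [1, 1, 1], [5, 5, 5],
        [1, 3, 5], [1, 5, 3],
        [3, 1, 5], [3, 5, 1],
        [5, 1, 3], [5, 3, 1],
    ]
)
c = 0.75  # constant offset added to every threshold


def sample_block():
    """Toy example: sample one block's duration, patch means, and initial thresholds."""
    duration_min = rng.uniform(60, 180)  # block duration in minutes
    lam = lambda_set[rng.integers(len(lambda_set))]  # patch means for this block
    # Initial threshold per patch; whenever a patch reaches its threshold,
    # a new value is resampled from the same c + Exp(1/lambda_i) distribution
    thresholds = c + rng.exponential(scale=lam)
    return duration_min, lam, thresholds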
Set up environment#
Create and activate a virtual environment named aeon-social-analysis
using uv.
uv venv aeon-social-analysis --python ">=3.11"
source aeon-social-analysis/bin/activate # Unix
.\aeon-social-analysis\Scripts\activate # Windows
Install the required packages and their dependencies.
uv pip install matplotlib numpy pandas plotly seaborn statsmodels pyyaml pyarrow tqdm scipy jupyter
Import libraries and define variables and helper functions#
import math
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
import statsmodels.api as sm
import yaml
from plotly.subplots import make_subplots
from scipy.stats import binomtest, ttest_rel, wilcoxon
from tqdm.notebook import tqdm
Note
Change data_dir and save_dir to the paths where your local dataset (the parquet files) is stored and where you want to save the results, respectively.
# SET THESE VARIABLES ACCORDINGLY
data_dir = Path("")
save_dir = Path("")
# Load metadata
os.makedirs(data_dir, exist_ok=True)
with open(data_dir / "Metadata_aeon3.yml", "r") as file:
metadata_aeon3 = yaml.safe_load(file)
with open(data_dir / "Metadata_aeon4.yml", "r") as file:
metadata_aeon4 = yaml.safe_load(file)
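In addition to these imports, later cells use variables and helper functions (e.g. experiments, light_on/light_off, cm2px, load_experiment_data, extract_proximity_periods, patch_df_to_dict) that the full notebook defines in this section but that are not reproduced here. As a rough guide to the interface they assume, below is a hedged sketch of load_experiment_data; the parquet file layout is an assumption, and the real helper also accepts options such as trim_days and can load data across all experiments at once.

def load_experiment_data(data_dir, experiment=None, periods=None, data_types=None, trim_days=None):
    """Hedged sketch only -- not the actual helper used in this notebook.

    Returns a dict keyed "<period>_<data_type>" (e.g. "social_retreat");
    when no periods are given the key prefix is the string "None"
    (e.g. "None_patch"), matching how results are accessed below.
    The parquet file naming here is an assumption.
    """
    data = {}
    prefix = f"{experiment['name']}_" if experiment is not None else ""
    for period in periods or [None]:
        for data_type in data_types or []:
            fname = f"{prefix}{period}_{data_type}.parquet"  # hypothetical layout
            data[f"{period}_{data_type}"] = pd.read_parquet(data_dir / fname)
    return data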
Dominance plots#
# Parameters
DISTANCE_THRESH = 50 # px
MIN_DUR, MAX_DUR = 2, 5 # seconds
PADDING_SEC = 2
SAMPLE_SIZE = 100
# Load social-period retreat and position data for a given experiment
experiment = experiments[0]
data = load_experiment_data(
experiment=experiment,
data_dir=data_dir,
periods=["social"],
data_types=["retreat", "position"],
)
social_retreat_df = data["social_retreat"]
social_position_df = data["social_position"]
exp, acq_pc = experiment["name"].split("-", 1)
metadata = metadata_aeon3 if acq_pc == "aeon3" else metadata_aeon4
inner_radius = float(metadata["ActiveRegion"]["ArenaInnerRadius"])
outer_radius = float(metadata["ActiveRegion"]["ArenaOuterRadius"])
center_x = float(metadata["ActiveRegion"]["ArenaCenter"]["X"])
center_y = float(metadata["ActiveRegion"]["ArenaCenter"]["Y"])
nest_y1 = float(metadata["ActiveRegion"]["NestRegion"]["ArrayOfPoint"][1]["Y"])
nest_y2 = float(metadata["ActiveRegion"]["NestRegion"]["ArrayOfPoint"][2]["Y"])
gate_width = 20
gate_coordinates = []
for device in metadata["Devices"]:
if "Gate" in device and "Rfid" in device:
gate_coordinates.append(metadata["Devices"][device]["Location"])
# Compute squared distance from arena center
social_position_df["dist2"] = (social_position_df["x"] - center_x) ** 2 + (
social_position_df["y"] - center_y
) ** 2
# Build “in‐corridor” mask (between inner & outer radii)
mask_corridor = social_position_df["dist2"].between(inner_radius**2, outer_radius**2)
# Exclude the nest region (to the right of center, between nest_y1 & nest_y2)
mask_nest = ~(
(social_position_df["x"] > center_x)
& (social_position_df["y"].between(nest_y1, nest_y2))
)
# Exclude all gate regions (within gate_width of any gate)
mask_gate = pd.Series(True, index=social_position_df.index)
for loc in gate_coordinates:
gx, gy = float(loc["X"]), float(loc["Y"])
d2 = (social_position_df["x"] - gx) ** 2 + (social_position_df["y"] - gy) ** 2
mask_gate &= d2 > gate_width**2
# Combine spatial masks
spatial_mask = mask_corridor & mask_nest & mask_gate
# Apply spatial filter
df_spatial = social_position_df.loc[spatial_mask].copy()
df_spatial = df_spatial.reset_index()
# Exclude retreat‐event timestamps
df_spatial["time"] = pd.to_datetime(df_spatial["time"])
social_retreat_df["start_timestamp"] = pd.to_datetime(
social_retreat_df["start_timestamp"]
)
social_retreat_df["end_timestamp"] = pd.to_datetime(social_retreat_df["end_timestamp"])
# Build a mask for any retreat interval
mask_retreat = pd.Series(False, index=df_spatial.index)
for start, end in zip(
social_retreat_df["start_timestamp"], social_retreat_df["end_timestamp"]
):
mask_retreat |= df_spatial["time"].between(start, end)
# Final filtered DataFrame
df_filtered = (
df_spatial.loc[~mask_retreat] # drop retreat frames
.drop(columns=["dist2"]) # clean up helper column
.reset_index(drop=True)
)
# Run proximity and filter durations
starts, ends = extract_proximity_periods(df_filtered, DISTANCE_THRESH)
# Compute durations & keep short ones
durations = [(e - s).total_seconds() for s, e in zip(starts, ends)]
periods = [
(s, e, d) for s, e, d in zip(starts, ends, durations) if MIN_DUR < d < MAX_DUR
]
periods.sort(key=lambda x: x[2])
filtered_starts = [p[0] for p in periods]
filtered_ends = [p[1] for p in periods]
filtered_durations = [p[2] for p in periods]
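extract_proximity_periods is one of the helpers defined earlier in the notebook. Judging from how it is used here, it takes the spatially filtered position frame and a pixel distance threshold, and returns the start and end timestamps of intervals during which the two mice were within that distance of each other. A minimal sketch under that assumption (the actual implementation may differ):

def extract_proximity_periods(df, distance_thresh):
    """Hedged sketch: intervals where the two mice are within `distance_thresh` px.

    Assumes `df` has columns "time", "identity_name", "x", "y".
    """
    # One row per timestamp, one (x, y) column pair per subject
    wide = df.pivot_table(index="time", columns="identity_name", values=["x", "y"])
    a, b = wide["x"].columns[:2]
    dist = np.hypot(wide["x"][a] - wide["x"][b], wide["y"][a] - wide["y"][b])
    close = dist < distance_thresh
    # Rising (False->True) and falling (True->False) edges of the proximity signal
    edges = close.astype(int).diff().fillna(0)
    starts = close.index[(edges == 1).to_numpy()]
    ends = close.index[(edges == -1).to_numpy()]
    if close.iloc[0]:  # proximity already ongoing at the first frame
        starts = starts.insert(0, close.index[0])
    if close.iloc[-1]:  # proximity still ongoing at the last frame
        ends = ends.append(pd.Index([close.index[-1]]))
    return list(starts), list(ends)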
Load data#
social_retreat_df_all_exps = []
social_fight_df_all_exps = []
social_patchinfo_df_all_exps = []
social_patch_df_all_exps = []
pbar = tqdm(
[experiments[i] for i in [0, 1, 4, 5]],
desc="Loading experiments",
unit="experiment",
)
for exp in pbar:
data = load_experiment_data(
experiment=exp,
data_dir=data_dir,
periods=["social"],
data_types=["retreat", "fight", "patchinfo", "patch"],
# trim_days=1 # Optional: trim
)
df_retreat = data["social_retreat"]
df_fight = data["social_fight"]
social_patchinfo_df = data["social_patchinfo"]
social_patch_df = data["social_patch"]
df_retreat["experiment_name"] = exp["name"]
df_fight["experiment_name"] = exp["name"]
social_retreat_df_all_exps.append(df_retreat)
social_fight_df_all_exps.append(df_fight)
social_patchinfo_df_all_exps.append(social_patchinfo_df)
social_patch_df_all_exps.append(social_patch_df)
social_retreat_df_all_exps = pd.concat(social_retreat_df_all_exps, ignore_index=True)
social_fight_df_all_exps = pd.concat(social_fight_df_all_exps, ignore_index=True)
social_patchinfo_df_all_exps = pd.concat(
social_patchinfo_df_all_exps, ignore_index=True
)
social_patch_df_all_exps = pd.concat(social_patch_df_all_exps, ignore_index=True)
tube_test_data = {
"social0.2-aeon3": {
"BAA-1104045": {"pre": 2, "post": 1},
"BAA-1104047": {"pre": 8, "post": 9},
},
"social0.2-aeon4": {
"BAA-1104048": {"pre": 7, "post": 8},
"BAA-1104049": {"pre": 3, "post": 2},
},
"social0.4-aeon3": {
"BAA-1104794": {"pre": 4, "post": 2},
"BAA-1104792": {"pre": 12, "post": 13},
},
"social0.4-aeon4": {
"BAA-1104795": {"pre": 10, "post": 12},
"BAA-1104797": {"pre": 4, "post": 3},
},
}
Tube test results, for reference:
SOCIAL 0.2
Pre-social tube test results:
BAA-1104045: 2, BAA-1104047: 8
BAA-1104048: 7, BAA-1104049: 3
Post-social tube test results:
BAA-1104045: 1, BAA-1104047: 9
BAA-1104048: 8, BAA-1104049: 2
SOCIAL 0.4
Pre-social tube test results:
BAA-1104795: 10, BAA-1104797: 4
BAA-1104792: 12, BAA-1104794: 4
Post-social tube test results:
BAA-1104795: 12, BAA-1104797: 3
BAA-1104792: 13, BAA-1104794: 2
1. Fights and retreats raster plots#
# Build lookup of social_start & compute per-exp offset to align at earliest start-time
exp_df = pd.DataFrame(experiments)
exp_df["social_start"] = pd.to_datetime(exp_df["social_start"])
exp_df["tod_sec"] = (
exp_df["social_start"].dt.hour * 3600
+ exp_df["social_start"].dt.minute * 60
+ exp_df["social_start"].dt.second
)
# Earliest start-time (seconds since midnight)
min_tod = exp_df["tod_sec"].min()
# Offset seconds so each experiment’s Day 0 lines up at min_tod
exp_df["offset_sec"] = exp_df["tod_sec"] - min_tod
# Compute absolute baseline timestamp per experiment
exp_df["baseline_ts"] = exp_df["social_start"] - pd.to_timedelta(
exp_df["offset_sec"], unit="s"
)
exp_df = exp_df[["name", "baseline_ts"]]
# Keep original experiment order
exp_order = [e["name"] for e in experiments]
dark_color = "#555555"
# Dataframes and titles to plot
raster_configs = [
(social_fight_df_all_exps, "Raster Plot: Social Fights", "fights"),
(social_retreat_df_all_exps, "Raster Plot: Social Retreats", "retreats"),
]
# Print aligned Day 0 hour (earliest start-time)
aligned_hour = min_tod // 3600
print(f"Aligned Day 0 starts at hour {aligned_hour}")
# Compute global end time across both datasets
all_events = pd.concat(
[
social_fight_df_all_exps[["experiment_name", "start_timestamp"]],
social_retreat_df_all_exps[["experiment_name", "start_timestamp"]],
]
)
all_events["start_timestamp"] = pd.to_datetime(all_events["start_timestamp"])
# Merge baseline timestamps
all_events = all_events.merge(
exp_df, left_on="experiment_name", right_on="name", how="left"
)
# Compute seconds since baseline
all_events["rel_sec"] = (
all_events["start_timestamp"] - all_events["baseline_ts"]
).dt.total_seconds()
# Find last event
last = all_events.loc[all_events["rel_sec"].idxmax()]
# Compute end day and hour
end_day = int(last["rel_sec"] // 86400)
end_hour = int((last["rel_sec"] % 86400) // 3600)
print(f"Global end at Day {end_day}, hour {end_hour}")
# Plot each raster
for df, title, behavior in raster_configs:
df2 = df[["experiment_name", "start_timestamp"]].copy()
df2["start_timestamp"] = pd.to_datetime(df2["start_timestamp"])
df2 = df2.merge(exp_df, left_on="experiment_name", right_on="name", how="left")
# Compute rel_days
df2["rel_days"] = (
df2["start_timestamp"] - df2["baseline_ts"]
).dt.total_seconds() / 86400.0
# Enforce experiment ordering
df2["experiment_name"] = pd.Categorical(
df2["experiment_name"], categories=exp_order, ordered=True
)
# Draw
fig = go.Figure(
go.Scatter(
x=df2["rel_days"],
y=df2["experiment_name"],
mode="markers",
marker=dict(symbol="line-ns", color=dark_color, size=8, line_width=1.2),
showlegend=False,
hovertemplate="Day: %{x:.2f}<br>Experiment: %{y}<extra></extra>",
)
)
fig.update_yaxes(
autorange="reversed",
title_text="Experiment",
showgrid=False,
zeroline=False,
showline=False,
ticks="",
)
fig.update_xaxes(
title_text="Days since aligned Day 0",
tick0=0,
dtick=1,
tickformat=".0f",
showgrid=False,
zeroline=False,
showline=True,
linecolor="black",
)
fig.update_layout(
template="simple_white",
margin=dict(l=120, r=20, t=60, b=40),
height=100 + 20 * len(exp_order),
title=dict(text=title, x=0.5),
)
# pio.write_image(
# fig,
# save_dir / f"{behavior}_raster.svg",
# format="svg"
# )
fig.show()
Aligned Day 0 starts at hour 10.0
Global end at Day 14, hour 1


3. Pre/during/post tube test comparison#
# Total win counts per subject during social phase
social_counts_df = (
social_retreat_df_all_exps.groupby(["experiment_name", "winner_identity"])
.size()
.unstack(fill_value=0)
)
# Compute per-experiment summary metrics for dominant subject
records = []
for exp, subj_dict in tube_test_data.items():
soc_counts = social_counts_df.loc[exp]
total_pre = sum(d["pre"] for d in subj_dict.values())
total_post = sum(d["post"] for d in subj_dict.values())
total_soc = soc_counts.sum()
# Identify dominant by total wins across all phases
total_wins = {
subj: subj_dict[subj]["pre"] + soc_counts.get(subj, 0) + subj_dict[subj]["post"]
for subj in subj_dict
}
dominant = max(total_wins, key=total_wins.get)
# Compute win rates for dominant subject
pre_rate = subj_dict[dominant]["pre"] / total_pre
social_rate = soc_counts[dominant] / total_soc
post_rate = subj_dict[dominant]["post"] / total_post
baseline_rate = (subj_dict[dominant]["pre"] + subj_dict[dominant]["post"]) / (
total_pre + total_post
)
records.append(
{
"experiment": exp,
"dominant": dominant,
"pre_rate": pre_rate,
"social_rate": social_rate,
"post_rate": post_rate,
"baseline_rate": baseline_rate,
}
)
summary_df = pd.DataFrame(records)
# Test whether the dominant subject's social-phase win rate differs from its pooled pre+post rate (Wilcoxon signed-rank)
stat, p_value = wilcoxon(summary_df["baseline_rate"], summary_df["social_rate"])
print(f"Wilcoxon signed-rank: W = {stat:.3f}, p-value = {p_value:.3f}\n")
# Per-experiment plotting and binomial tests
x_positions = [0, 1, 2]
x_labels = ["Tube test (pre)", "Two weeks foraging", "Tube test (post)"]
for experiment, subjects in tube_test_data.items():
ids = list(subjects.keys())
pre_scores = {s: subjects[s]["pre"] for s in ids}
post_scores = {s: subjects[s]["post"] for s in ids}
soc_counts = social_counts_df.loc[experiment]
# Identify dominant and subordinate
dominant = summary_df.loc[summary_df["experiment"] == experiment, "dominant"].iloc[
0
]
subordinate = [s for s in ids if s != dominant][0]
# Win/loss counts
k_pre = pre_scores[dominant]
n_pre = k_pre + pre_scores[subordinate]
k_post = post_scores[dominant]
n_post = k_post + post_scores[subordinate]
k_social = soc_counts[dominant]
n_social = soc_counts.sum()
# Binomial tests (one-sided: dominant > 50%)
k_pool = k_pre + k_post
n_pool = n_pre + n_post
p_pool = binomtest(k_pool, n_pool, p=0.5, alternative="greater").pvalue
p_social = binomtest(k_social, n_social, p=0.5, alternative="greater").pvalue
# Print results
print(f"\n=== {experiment} ({dominant} dominant) ===")
print(f"Pooled Pre+Post: wins={k_pool}/{n_pool} → binom p={p_pool:.3f}")
print(f"Social: wins={k_social}/{n_social} → binom p={p_social:.3f}")
# Plot win rates for both mice
proportions = {
dominant: {
"pre": k_pre / n_pre,
"social": k_social / n_social,
"post": k_post / n_post,
"color": "gold",
},
subordinate: {
"pre": pre_scores[subordinate] / n_pre,
"social": soc_counts[subordinate] / n_social,
"post": post_scores[subordinate] / n_post,
"color": "brown",
},
}
fig = go.Figure()
for subj, vals in proportions.items():
for i, x in enumerate(x_positions):
y = [vals["pre"], vals["social"], vals["post"]][i]
fig.add_trace(
go.Scatter(
x=[x, x],
y=[0, y],
mode="lines+markers",
name=subj,
line=dict(color=vals["color"]),
showlegend=(i == 0),
)
)
fig.update_layout(
title=experiment,
plot_bgcolor="white",
width=400,
height=400,
xaxis=dict(tickvals=x_positions, ticktext=x_labels, range=[-0.5, 2.5]),
yaxis=dict(
title="Proportion of victories",
range=[0, 1],
ticks="outside",
ticklen=5,
tickwidth=2,
showline=True,
linecolor="black",
),
)
fig.show()
# if experiment == experiments[-1]['name']:
# pio.write_image(
# fig,
# str(save_dir / f"tube_test_win_rates.svg"),
# format="svg"
# )
Wilcoxon signed-rank: W = 2.000, p-value = 0.375
=== social0.2-aeon3 (BAA-1104047 dominant) ===
Pooled Pre+Post: wins=17/20 → binom p=0.001
Social: wins=35/55 → binom p=0.029

=== social0.2-aeon4 (BAA-1104048 dominant) ===
Pooled Pre+Post: wins=15/20 → binom p=0.021
Social: wins=183/251 → binom p=0.000

=== social0.4-aeon3 (BAA-1104792 dominant) ===
Pooled Pre+Post: wins=25/31 → binom p=0.000
Social: wins=47/66 → binom p=0.000

=== social0.4-aeon4 (BAA-1104795 dominant) ===
Pooled Pre+Post: wins=22/29 → binom p=0.004
Social: wins=139/172 → binom p=0.000

4. Dominance summary plot#
social_win_proportions = (
social_retreat_df_all_exps.groupby(["experiment_name", "winner_identity"])
.size()
.unstack(fill_value=0)
)
social_win_proportions = social_win_proportions.div(
social_win_proportions.sum(axis=1), axis=0
)
# Define plot
fig = go.Figure()
legend_labels_added = {"gold": False, "brown": False}
for experiment, subjects in tube_test_data.items():
ids = list(subjects.keys())
pre_scores = {subj: subjects[subj]["pre"] for subj in ids}
post_scores = {subj: subjects[subj]["post"] for subj in ids}
pre_dominant = max(pre_scores, key=pre_scores.get)
post_dominant = max(post_scores, key=post_scores.get)
if experiment in social_win_proportions.index:
social_scores = social_win_proportions.loc[experiment]
social_dominant = social_scores.idxmax()
else:
raise ValueError(f"No social data for {experiment}")
if len({pre_dominant, post_dominant, social_dominant}) != 1:
raise ValueError(
f"Inconsistent dominant subject in {experiment}: pre={pre_dominant}, social={social_dominant}, post={post_dominant}"
)
dominant = pre_dominant
subordinate = [s for s in ids if s != dominant][0]
total_pre = pre_scores[dominant] + pre_scores[subordinate]
total_post = post_scores[dominant] + post_scores[subordinate]
social_dom = (
social_win_proportions.at[experiment, dominant]
if dominant in social_win_proportions.columns
else 0
)
social_sub = (
social_win_proportions.at[experiment, subordinate]
if subordinate in social_win_proportions.columns
else 0
)
proportions = {
dominant: {
"pre": -pre_scores[dominant] / total_pre,
"post": post_scores[dominant] / total_post,
"social": -social_dom,
"color": "gold",
"legend": "Dominant",
},
subordinate: {
"pre": -pre_scores[subordinate] / total_pre,
"post": post_scores[subordinate] / total_post,
"social": -social_sub,
"color": "brown",
"legend": "Subordinate",
},
}
for subj, vals in proportions.items():
color = vals["color"]
legend_name = vals["legend"] if not legend_labels_added[color] else None
legend_labels_added[color] = True
# Pre point
fig.add_trace(
go.Scatter(
x=[vals["pre"]],
y=[vals["social"]],
mode="markers",
marker=dict(color=color, size=10),
name=legend_name,
showlegend=legend_name is not None,
)
)
# Post point
fig.add_trace(
go.Scatter(
x=[vals["post"]],
y=[-vals["social"]],
mode="markers",
marker=dict(color=color, size=10),
name=legend_name,
showlegend=False,
)
)
# Final layout
# Hide everything
fig.update_xaxes(
showgrid=False, showline=False, zeroline=False, showticklabels=False, ticks=""
)
fig.update_yaxes(
showgrid=False, showline=False, zeroline=False, showticklabels=False, ticks=""
)
# Draw the two axes as black lines
fig.add_shape(type="line", x0=-1, x1=1, y0=0, y1=0, line=dict(color="black", width=1))
fig.add_shape(type="line", x0=0, x1=0, y0=-1, y1=1, line=dict(color="black", width=1))
# Draw little ticks at ±1
tick_len = 0.02
for t in (-1, 1):
# x‐axis tick
fig.add_shape(
type="line",
x0=t,
x1=t,
y0=-tick_len,
y1=+tick_len,
line=dict(color="black", width=1),
)
# y‐axis tick
fig.add_shape(
type="line",
x0=-tick_len,
x1=+tick_len,
y0=t,
y1=t,
line=dict(color="black", width=1),
)
# Annotate labels at ±1
for t, txt in [(-1, "−1"), (1, "1")]:
# x‐axis label
fig.add_annotation(
x=t,
y=0,
text=txt,
yshift=-16, # move it down in px
showarrow=False,
font=dict(size=14),
)
# y‐axis label
fig.add_annotation(
x=0,
y=t,
text=txt,
xshift=16, # move it left in px
showarrow=False,
font=dict(size=14),
)
# Finally, re-add title & legend layout
fig.update_layout(
title="Pre vs Post Tube Test (Reflected by Social Proportion)",
plot_bgcolor="white",
width=700,
height=700,
legend=dict(title="Subject role", orientation="v", x=1.02, y=1),
)
# Draw y=x as a dashed grey line (behind the points)
fig.add_shape(
type="line",
x0=-1,
y0=-1,
x1=1,
y1=1,
line=dict(color="lightgrey", width=1, dash="dash"),
layer="below",
)
# Re-add title, legend, and now axis titles
fig.update_layout(
title="Pre vs Post Tube Test (Reflected by Social Proportion)",
xaxis_title="Tube test victory proportion (−pre, +post)",
yaxis_title="Social phase proportion (−plotted pre, +plotted post)",
plot_bgcolor="white",
width=700,
height=700,
legend=dict(title="Subject role", orientation="v", x=1.02, y=1),
)
fig.show()
# pio.write_image(
# fig,
# str(save_dir / "tube_test_scatter.svg"),
# format="svg"
# )

5. Dominant vs subordinate comparison of time spent and distance spun at best patch#
df = social_patch_df_all_exps.merge(
social_patchinfo_df_all_exps[
["experiment_name", "block_start", "patch_name", "patch_rate"]
],
on=["experiment_name", "block_start", "patch_name"],
how="left",
).assign(dummy=lambda x: x["patch_name"].str.contains("dummy", case=False))
# Keep only blocks where non-dummy patches have exactly 3 different rates
df = df[
df.groupby(["experiment_name", "block_start"])["patch_rate"].transform(
lambda s: s[~df.loc[s.index, "dummy"]].nunique() == 3
)
]
# Create a ranking only for non-dummy patches
non_dummy_ranks = (
df[~df["dummy"]]
.groupby(["experiment_name", "block_start"])["patch_rate"]
.rank(method="dense")
)
# Add the ranks back to the full dataframe
df["patch_rank"] = np.nan
df.loc[~df["dummy"], "patch_rank"] = non_dummy_ranks
# Assign difficulty (rank 1=hard, rank 3=easy, rank 2=medium)
df["patch_difficulty"] = np.where(
df["dummy"],
"dummy",
np.where(
df["patch_rank"] == 1, "hard", np.where(df["patch_rank"] == 3, "easy", "medium")
),
)
df = df.drop(columns=["patch_rank"])
# Compute social‐win proportions
swp = (
social_retreat_df_all_exps.groupby(["experiment_name", "winner_identity"])
.size()
.unstack(fill_value=0)
)
swp = swp.div(swp.sum(axis=1), axis=0)
# Build the dominance DataFrame in one comprehension
dominance_df = pd.DataFrame(
[
{
"experiment_name": exp,
"dominant": dom,
"subordinate": next(s for s in subs if s != dom),
}
for exp, subs in tube_test_data.items()
if exp in swp.index
# pick the pre‐tube top scorer…
for dom in [max(subs, key=lambda s: subs[s]["pre"])]
# …only keep if post‐tube and social‐win agree
if dom == max(subs, key=lambda s: subs[s]["post"]) == swp.loc[exp].idxmax()
]
)
dominance_df
|   | experiment_name | dominant | subordinate |
|---|---|---|---|
| 0 | social0.2-aeon3 | BAA-1104047 | BAA-1104045 |
| 1 | social0.2-aeon4 | BAA-1104048 | BAA-1104049 |
| 2 | social0.4-aeon3 | BAA-1104792 | BAA-1104794 |
| 3 | social0.4-aeon4 | BAA-1104795 | BAA-1104797 |
# Compute per‐subject, per‐block easy‐time fraction, but only keep blocks ≥1 min
time_df = (
df.groupby(["experiment_name", "block_start", "subject_name", "patch_difficulty"])[
"in_patch_time"
]
.sum()
.unstack("patch_difficulty", fill_value=0)
.assign(total_time=lambda d: d.sum(axis=1))
# filter out blocks with < 60 s total patch time
.loc[lambda d: d["total_time"] >= 60]
.assign(easy_ratio=lambda d: d["easy"] / d["total_time"])
.reset_index()
)
# Merge in dominant/subordinate labels and tag role
time_df = time_df.merge(dominance_df, on="experiment_name", how="left").assign(
role=lambda d: np.where(
d["subject_name"] == d["dominant"], "dominant", "subordinate"
)
)
# Box‐plot of easy_ratio by role
fig = px.box(
time_df,
x="role",
y="easy_ratio",
points="all",
category_orders={"role": ["dominant", "subordinate"]},
title="Relative time in easy patch by role (≥ 1 min total time spent in patches per block)",
)
fig.update_yaxes(title="Easy‐patch time / total patch time")
fig.update_xaxes(title="")
fig.show()
# Print the medians
medians = time_df.groupby("role")["easy_ratio"].median()
print("Median easy patch time ratio by role:")
for role, median in medians.items():
print(f"{role}: {median:.2f}")

Median easy patch time ratio by role:
dominant: 0.52
subordinate: 0.53
# Get per‐subject, per‐block time by difficulty
time_df = (
df.groupby(["experiment_name", "block_start", "subject_name", "patch_difficulty"])[
"in_patch_time"
]
.sum()
.unstack("patch_difficulty", fill_value=0)
# drop any block‐subject that didn’t spend ≥10 s in each of hard/medium/easy
.loc[lambda d: (d[["hard", "medium", "easy"]] >= 10).all(axis=1)]
.assign(
total_time=lambda d: d.sum(axis=1),
easy_ratio=lambda d: d["easy"] / d["total_time"],
)
.reset_index()
)
# Merge in dominant/subordinate and label role
time_df = time_df.merge(dominance_df, on="experiment_name", how="left").assign(
role=lambda d: np.where(
d["subject_name"] == d["dominant"], "dominant", "subordinate"
)
)
# Box‐plot of easy_ratio by role
fig = px.box(
time_df,
x="role",
y="easy_ratio",
points="all",
category_orders={"role": ["dominant", "subordinate"]},
title="Easy‐patch fraction by role (≥10 s in each non‐dummy patch)",
)
fig.update_yaxes(title="Easy time / total patch time")
fig.update_xaxes(title="")
fig.show()

# Compute true traveled distance for each patch‐visit, thresholding <1→0
df2 = (
df.assign(
# sum of abs steps = true distance
traveled=lambda d: d["wheel_cumsum_distance_travelled"].apply(
lambda arr: np.sum(np.abs(np.diff(arr))) if len(arr) > 1 else 0
)
)
# any tiny (<1) distance becomes 0
.assign(traveled=lambda d: d["traveled"].mask(d["traveled"] < 1, 0))
)
# Pivot to one row per subject‐block with columns [hard, medium, easy]
dist_df = (
df2.groupby(["experiment_name", "block_start", "subject_name", "patch_difficulty"])[
"traveled"
]
.sum()
.unstack("patch_difficulty", fill_value=0)
.assign(
total_dist=lambda d: d.sum(axis=1),
easy_ratio=lambda d: d["easy"] / d["total_dist"],
)
.reset_index()
)
# Merge in dominance & tag role
dist_df = dist_df.merge(dominance_df, on="experiment_name", how="left").assign(
role=lambda d: np.where(
d["subject_name"] == d["dominant"], "dominant", "subordinate"
)
)
# Box‐plot of easy_ratio by role
fig = px.box(
dist_df,
x="role",
y="easy_ratio",
points="all",
category_orders={"role": ["dominant", "subordinate"]},
title="Easy‐patch distance fraction by role (tiny spins <1 set to 0)",
)
fig.update_yaxes(title="Easy‐patch distance / total distance")
fig.update_xaxes(title="")
fig.show()
# Print the medians
medians = dist_df.groupby("role")["easy_ratio"].median()
print("Median easy patch distance ratio by role:")
for role, median in medians.items():
print(f"{role}: {median:.2f}")

Median easy patch distance ratio by role:
dominant: 0.33
subordinate: 0.33
# Compute true traveled distance per patch‐visit, tiny spins → 0
df2 = df.assign(
traveled=lambda d: d["wheel_cumsum_distance_travelled"].apply(
lambda arr: np.sum(np.abs(np.diff(arr))) if len(arr) > 1 else 0
)
).assign(traveled=lambda d: d["traveled"].mask(d["traveled"] < 1, 0))
# Pivot to one row per subject-block, keeping only rows where hard, medium, and easy each exceed 500 cm
dist_df = (
df2.groupby(["experiment_name", "block_start", "subject_name", "patch_difficulty"])[
"traveled"
]
.sum()
.unstack("patch_difficulty", fill_value=0)
.loc[lambda d: (d[["hard", "medium", "easy"]] > 500).all(axis=1)]
.assign(
total_dist=lambda d: d.sum(axis=1),
easy_ratio=lambda d: d["easy"] / d["total_dist"],
)
.reset_index()
)
# Merge in dominance & tag role
dist_df = dist_df.merge(dominance_df, on="experiment_name", how="left").assign(
role=lambda d: np.where(
d["subject_name"] == d["dominant"], "dominant", "subordinate"
)
)
# Box‐plot of easy_ratio by role (dominant first)
fig = px.box(
dist_df,
x="role",
y="easy_ratio",
points="all",
category_orders={"role": ["dominant", "subordinate"]},
title="Easy‐patch distance fraction by role (all wheels >500cm)",
)
fig.update_yaxes(title="Easy‐patch distance / total distance")
fig.update_xaxes(title="")
fig.show()

Patch preference plots#
Load data#
data = load_experiment_data(
data_dir=data_dir,
data_types=["patch", "patchinfo"],
)
patch_df = data["None_patch"]
patch_info_df = data["None_patchinfo"]
block_subject_patch_data_social_combined = patch_df[
patch_df["period"] == "social"
].copy()
block_subject_patch_data_social_combined.drop(columns=["period"], inplace=True)
block_subject_patch_data_social_dict = patch_df_to_dict(
block_subject_patch_data_social_combined
)
block_subject_patch_data_post_social_combined = patch_df[
patch_df["period"] == "postsocial"
].copy()
block_subject_patch_data_post_social_combined.drop(columns=["period"], inplace=True)
block_subject_patch_data_post_social_dict = patch_df_to_dict(
block_subject_patch_data_post_social_combined
)
block_subject_patch_data_social_first_half_combined = get_first_half_social(
patch_df
).copy()
block_subject_patch_data_social_first_half_combined.drop(
columns=["period"], inplace=True
)
block_subject_patch_data_social_first_half_dict = patch_df_to_dict(
block_subject_patch_data_social_first_half_combined
)
patch_info_dict = patch_info_df_to_dict(patch_info_df)
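patch_df_to_dict, patch_info_df_to_dict, and get_first_half_social are helpers from the collapsed section above. From how their outputs are indexed below, the *_to_dict helpers appear to split a combined dataframe into per-experiment dataframes keyed by experiment name; a minimal sketch under that assumption:

def patch_df_to_dict(patch_df):
    """Hedged sketch: per-experiment dict of patch data keyed by experiment name."""
    return {name: grp.copy() for name, grp in patch_df.groupby("experiment_name")}


def patch_info_df_to_dict(patch_info_df):
    """Presumed analogous for the per-block patch info."""
    return {name: grp.copy() for name, grp in patch_info_df.groupby("experiment_name")}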
1. Wheel distance spun per block, averaged by the number of mice#
block_subject_patch_data_social_combined["final_wheel_cumsum"] = (
block_subject_patch_data_social_combined["wheel_cumsum_distance_travelled"].apply(
lambda x: x[-1] if isinstance(x, np.ndarray) and len(x) > 0 else 0
)
)
wheel_total_dist_averaged_social = (
block_subject_patch_data_social_combined.groupby("block_start")[
"final_wheel_cumsum"
].sum()
/ 2
)
wheel_total_dist_averaged_social = wheel_total_dist_averaged_social.reset_index()
block_subject_patch_data_post_social_combined["final_wheel_cumsum"] = (
block_subject_patch_data_post_social_combined[
"wheel_cumsum_distance_travelled"
].apply(lambda x: x[-1] if isinstance(x, np.ndarray) and len(x) > 0 else 0)
)
wheel_total_dist_averaged_post_social = (
block_subject_patch_data_post_social_combined.groupby("block_start")[
"final_wheel_cumsum"
]
.sum()
.reset_index()
)
wheel_total_dist_averaged_social["condition"] = "social"
wheel_total_dist_averaged_post_social["condition"] = "post_social"
wheel_total_dist_averaged = pd.concat(
[wheel_total_dist_averaged_social, wheel_total_dist_averaged_post_social]
)
fig = go.Figure()
fig = px.box(
wheel_total_dist_averaged,
x="condition",
y="final_wheel_cumsum",
points="all",
title="Wheel Distance Spun Per Block Averaged By Number Of Subjects",
labels={"final_wheel_cumsum": "Wheel Distance Spun Per Block (cm)"},
)
fig.show()

2. Probability of being in the hard patch over time#
prob_per_patch_social_first_half_dict = {}
prob_per_patch_social_dict = {}
prob_per_patch_post_social_dict = {}
prob_hard_patch_social_first_half_dict = {}
prob_hard_patch_social_dict = {}
prob_hard_patch_post_social_dict = {}
prob_hard_patch_mean_social_first_half_dict = {}
prob_hard_patch_mean_social_dict = {}
prob_hard_patch_mean_post_social_dict = {}
for exp in experiments:
exp_name = exp["name"]
block_subject_patch_data_social_first_half = (
block_subject_patch_data_social_first_half_dict[exp_name]
)
block_subject_patch_data_social = block_subject_patch_data_social_dict[exp_name]
block_subject_patch_data_post_social = block_subject_patch_data_post_social_dict[
exp_name
]
patch_info = patch_info_dict[exp_name]
# Compute patch probabilities
prob_per_patch_social_first_half_dict[exp_name] = compute_patch_probabilities(
block_subject_patch_data_social_first_half
)
prob_per_patch_social_dict[exp_name] = compute_patch_probabilities(
block_subject_patch_data_social
)
prob_per_patch_post_social_dict[exp_name] = compute_patch_probabilities(
block_subject_patch_data_post_social
)
# Extract hard patch probabilities
prob_hard_patch_social_first_half_dict[exp_name] = extract_hard_patch_probabilities(
prob_per_patch_social_first_half_dict[exp_name], patch_info
)
prob_hard_patch_social_dict[exp_name] = extract_hard_patch_probabilities(
prob_per_patch_social_dict[exp_name], patch_info
)
prob_hard_patch_post_social_dict[exp_name] = extract_hard_patch_probabilities(
prob_per_patch_post_social_dict[exp_name], patch_info
)
# Calculate the mean hard patch probability per pellet number
prob_hard_patch_mean_social_first_half_dict[exp_name] = (
prob_hard_patch_social_first_half_dict[exp_name]
.groupby("pellet_number")
.mean(numeric_only=True)
.reset_index()
)
prob_hard_patch_mean_social_dict[exp_name] = (
prob_hard_patch_social_dict[exp_name]
.groupby("pellet_number")
.mean(numeric_only=True)
.reset_index()
)
prob_hard_patch_mean_post_social_dict[exp_name] = (
prob_hard_patch_post_social_dict[exp_name]
.groupby("pellet_number")
.mean(numeric_only=True)
.reset_index()
)
# Combine the results
prob_hard_patch_social_first_half_combined = pd.concat(
prob_hard_patch_social_first_half_dict.values()
)
prob_hard_patch_social_combined = pd.concat(prob_hard_patch_social_dict.values())
prob_hard_patch_post_social_combined = pd.concat(
prob_hard_patch_post_social_dict.values()
)
prob_hard_patch_mean_social_first_half_combined = (
prob_hard_patch_social_first_half_combined.groupby("pellet_number")
.mean(numeric_only=True)
.reset_index()
)
prob_hard_patch_mean_social_combined = (
prob_hard_patch_social_combined.groupby("pellet_number")
.mean(numeric_only=True)
.reset_index()
)
prob_hard_patch_mean_post_social_combined = (
prob_hard_patch_post_social_combined.groupby("pellet_number")
.mean(numeric_only=True)
.reset_index()
)
model_social_first_half, y_pred_social_first_half = analyze_patch_probabilities(
prob_hard_patch_mean_social_first_half_combined, "social first half"
)
model_social, y_pred_social = analyze_patch_probabilities(
prob_hard_patch_mean_social_combined, "social"
)
model_post_social, y_pred_post_social = analyze_patch_probabilities(
prob_hard_patch_mean_post_social_combined, "post-social"
)
P-value for the social first half slope: 5.936891089220609e-14
social first half model summary: OLS Regression Results
==============================================================================
Dep. Variable: y R-squared: 0.823
Model: OLS Adj. R-squared: 0.818
Method: Least Squares F-statistic: 153.4
Date: Thu, 21 Aug 2025 Prob (F-statistic): 5.94e-14
Time: 17:37:33 Log-Likelihood: 93.236
No. Observations: 35 AIC: -182.5
Df Residuals: 33 BIC: -179.4
Df Model: 1
Covariance Type: nonrobust
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
const 0.2611 0.006 43.528 0.000 0.249 0.273
x1 -0.0036 0.000 -12.384 0.000 -0.004 -0.003
==============================================================================
Omnibus: 1.794 Durbin-Watson: 0.708
Prob(Omnibus): 0.408 Jarque-Bera (JB): 1.503
Skew: -0.353 Prob(JB): 0.472
Kurtosis: 2.270 Cond. No. 42.3
==============================================================================
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
P-value for the social slope: 2.2319432890179306e-18
social model summary: OLS Regression Results
==============================================================================
Dep. Variable: y R-squared: 0.904
Model: OLS Adj. R-squared: 0.901
Method: Least Squares F-statistic: 311.6
Date: Thu, 21 Aug 2025 Prob (F-statistic): 2.23e-18
Time: 17:37:33 Log-Likelihood: 102.24
No. Observations: 35 AIC: -200.5
Df Residuals: 33 BIC: -197.4
Df Model: 1
Covariance Type: nonrobust
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
const 0.2490 0.005 53.700 0.000 0.240 0.258
x1 -0.0040 0.000 -17.654 0.000 -0.004 -0.004
==============================================================================
Omnibus: 2.251 Durbin-Watson: 0.657
Prob(Omnibus): 0.325 Jarque-Bera (JB): 1.851
Skew: -0.556 Prob(JB): 0.396
Kurtosis: 2.820 Cond. No. 42.3
==============================================================================
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
P-value for the post-social slope: 2.5099330775172165e-07
post-social model summary: OLS Regression Results
==============================================================================
Dep. Variable: y R-squared: 0.558
Model: OLS Adj. R-squared: 0.545
Method: Least Squares F-statistic: 41.73
Date: Thu, 21 Aug 2025 Prob (F-statistic): 2.51e-07
Time: 17:37:33 Log-Likelihood: 84.637
No. Observations: 35 AIC: -165.3
Df Residuals: 33 BIC: -162.2
Df Model: 1
Covariance Type: nonrobust
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
const 0.2283 0.008 29.774 0.000 0.213 0.244
x1 -0.0024 0.000 -6.460 0.000 -0.003 -0.002
==============================================================================
Omnibus: 0.420 Durbin-Watson: 0.735
Prob(Omnibus): 0.810 Jarque-Bera (JB): 0.556
Skew: 0.003 Prob(JB): 0.757
Kurtosis: 2.383 Cond. No. 42.3
==============================================================================
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
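analyze_patch_probabilities (also from the collapsed helper section) evidently fits an ordinary least squares regression of the mean hard-patch probability on pellet number over the first 35 pellets, prints the slope p-value and model summary shown above, and returns the fitted model plus its predictions for plotting. A hedged sketch consistent with those outputs (the 35-pellet cut-off is read off from the regression-line traces below):

def analyze_patch_probabilities(mean_df, label, n_points=35):
    """Hedged sketch: OLS of hard-patch probability vs. pellet number."""
    x = mean_df["pellet_number"].to_numpy()[:n_points]
    y = mean_df["prob_in_hard_patch"].to_numpy()[:n_points]
    X = sm.add_constant(x)  # adds the "const" (intercept) term
    model = sm.OLS(y, X).fit()
    y_pred = model.predict(X)  # regression line plotted below
    print(f"P-value for the {label} slope: {model.pvalues[1]}")
    print(f"{label} model summary: {model.summary()}")
    return model, y_pred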
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_social_first_half_combined["pellet_number"], # [0:35],
y=prob_hard_patch_mean_social_first_half_combined[
"prob_in_hard_patch"
], # [0:35],
mode="lines",
name="First Half of Social Data",
marker=dict(color="blue"),
)
)
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_social_combined["pellet_number"], # [0:35],
y=prob_hard_patch_mean_social_combined["prob_in_hard_patch"], # [0:35],
mode="lines",
name="Social Data",
marker=dict(color="red"),
)
)
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_post_social_combined["pellet_number"], # [0:35],
y=prob_hard_patch_mean_post_social_combined["prob_in_hard_patch"], # [0:35],
mode="lines",
name="Post Social Data",
marker=dict(color="#00CC96"),
)
)
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_social_first_half_combined["pellet_number"][0:35],
y=y_pred_social_first_half,
mode="lines",
name="Social First Half Linear Regression Line",
line=dict(dash="dash"),
marker=dict(color="blue"),
)
)
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_social_combined["pellet_number"][0:35],
y=y_pred_social,
mode="lines",
name="Social Linear Regression Line",
line=dict(dash="dash"),
marker=dict(color="red"),
)
)
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_social_combined["pellet_number"][0:35],
y=y_pred_post_social,
mode="lines",
name="Post Social Linear Regression Line",
line=dict(dash="dash"),
marker=dict(color="#00CC96"),
)
)
fig.update_layout(
title="Probability of being in hard patch over time",
xaxis_title="Pellet number in block",
yaxis_title="Hard patch probability",
)
fig.show()
# Save the figure as an SVG file
# fig.write_image("hard_patch_probability.svg")

# Define the number of rows and columns for the subplot grid
num_experiments = len(experiments)
num_cols = 2
num_rows = (num_experiments + num_cols - 1) // num_cols
# Create a subplot grid
fig = make_subplots(
rows=num_rows, cols=num_cols, subplot_titles=[exp["name"] for exp in experiments]
)
# Iterate over each experiment and add a plot to the grid
for i, exp in enumerate(experiments):
exp_name = exp["name"]
row = (i // num_cols) + 1
col = (i % num_cols) + 1
# Add the plot to the grid
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_social_first_half_dict[exp_name]["pellet_number"][
0:35
],
y=prob_hard_patch_mean_social_first_half_dict[exp_name][
"prob_in_hard_patch"
][0:35],
mode="lines",
name="First Half of Social Data",
marker=dict(color="blue"),
showlegend=(i == 0),
),
row=row,
col=col,
)
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_social_dict[exp_name]["pellet_number"][0:35],
y=prob_hard_patch_mean_social_dict[exp_name]["prob_in_hard_patch"][0:35],
mode="lines",
name="Social Data",
marker=dict(color="red"),
showlegend=(i == 0),
),
row=row,
col=col,
)
fig.add_trace(
go.Scatter(
x=prob_hard_patch_mean_post_social_dict[exp_name]["pellet_number"][0:35],
y=prob_hard_patch_mean_post_social_dict[exp_name]["prob_in_hard_patch"][
0:35
],
mode="lines",
name="Post Social Data",
marker=dict(color="#00CC96"),
showlegend=(i == 0),
),
row=row,
col=col,
)
fig.update_layout(height=800, width=1000)
fig.show()

Data overview plot#
Load data#
experiment = experiments[0]
data = load_experiment_data(
experiment=experiment,
data_dir=data_dir,
periods=["social"],
data_types=["patch", "position", "foraging", "weight"],
# trim_days=1 # Optional: trim
)
social_patch_df = data["social_patch"]
social_position_df = data["social_position"]
social_foraging_df = data["social_foraging"]
social_weight_df = data["social_weight"]
1. Position heatmaps over time per subject#
# Copy and sort dataframe
df = social_position_df.copy().sort_index()
# Compute time-of-day and flag dark vs light periods
df["tod"] = (
df.index.hour
+ df.index.minute / 60
+ df.index.second / 3600
+ df.index.microsecond / 1e6 / 3600
)
df["is_dark"] = (df["tod"] >= light_off) & (df["tod"] < light_on)
# Detect light/dark transitions and assign period IDs
first_dark = df.groupby("identity_name")["is_dark"].transform("first")
shifted = df.groupby("identity_name")["is_dark"].shift().fillna(first_dark)
df["light_change"] = df["is_dark"] != shifted
df["light_id"] = df.groupby("identity_name")["light_change"].cumsum().astype(int) + 1
# Save the original index name and reset to unique integer indices
original_index_name = df.index.name or "time"
df = df.reset_index()
# Initialize speed column
df["speed"] = 0.0
# Calculate the speed of each subject based on position data
for subject in df["identity_name"].unique():
subject_df = df[df["identity_name"] == subject].copy()
subject_df.sort_values(original_index_name, inplace=True) # sort by time
if subject_df.empty:
continue
# Calculate the difference in position and time
dxy = subject_df[["x", "y"]].diff().values[1:] # skip the first row (NaN)
dt_ms = (
subject_df[original_index_name].diff().dt.total_seconds().values[1:] * 1000
) # convert to milliseconds
# Calculate speed in cm/s
speed = np.linalg.norm(dxy, axis=1) / dt_ms * 1000 / cm2px # convert to cm/s
subject_df["speed"] = np.concatenate(([0], speed)) # add zero for the first row
# Apply a running average filter
k = np.ones(10) / 10 # running avg filter kernel (10 frames)
subject_df["speed"] = np.convolve(subject_df["speed"], k, mode="same")
# Update the original dataframe
mask = df["identity_name"] == subject
df.loc[mask, "speed"] = subject_df["speed"].values
# Set the index back to time
df = df.set_index(original_index_name)
# Remove rows where speed is above 250cm/s
df_filtered = df[df["speed"] <= 250]
df_plot = df_filtered.reset_index()
dark_color = "#555555"
light_color = "#CCCCCC"
subjects = sorted(df_plot["identity_name"].unique())
n_subj = len(subjects)
n_per = int(df_plot["light_id"].max())
# Set up subplots
fig, axes = plt.subplots(
nrows=n_subj,
ncols=n_per,
figsize=(2 * n_per, 2 * n_subj),
sharex=True,
sharey=True,
squeeze=False,
)
# Loop over axes and draw
for i, subj in enumerate(subjects):
for j in range(1, n_per + 1):
ax = axes[i, j - 1]
sub = df_plot[(df_plot.identity_name == subj) & (df_plot.light_id == j)]
if sub.empty:
ax.set_axis_off()
continue
col = dark_color if sub.is_dark.iloc[0] else light_color
plot_in_chunks(
ax, sub.x.values, sub.y.values, linestyle="-", linewidth=0.8, color=col
)
ax.set_aspect("equal", "box")
ax.axis("off")
# Scale bar in bottom-right
last_ax = axes[-1, -1]
length_px = 0.2 * 100 * cm2px
xmin, xmax = last_ax.get_xlim()
ymin, ymax = last_ax.get_ylim()
x1 = xmax - 0.02 * (xmax - xmin)
x0 = x1 - length_px
y0 = ymin + 0.02 * (ymax - ymin)
last_ax.plot([x0, x1], [y0, y0], "k-", lw=2)
last_ax.text(
(x0 + x1) / 2,
y0 - 0.06 * (ymax - ymin),
"0.2 m",
va="top",
ha="center",
fontsize=8,
color="k",
)
plt.tight_layout()
# svg_path = save_dir / "position_maps.pdf"
# plt.savefig(svg_path, format="pdf", dpi=300)
plt.show()

2. Locomotion speed over time per subject#
df_speed = df_filtered.copy()
df_speed.index.name = "time"
dark_color = "#555555"
light_color = "#CCCCCC"
subjects = sorted(df_speed["identity_name"].unique())
agg = (
df_speed[["identity_name", "speed"]]
.groupby("identity_name")
.resample("1s")
.mean()
.dropna()
.reset_index()
)
fig = go.Figure()
for subj in subjects:
sub = agg[agg["identity_name"] == subj]
fig.add_trace(
go.Scatter(
x=sub["time"],
y=sub["speed"],
mode="lines",
name=subj,
line=dict(width=1),
opacity=1,
)
)
fig.update_layout(
title="Speed over Time",
xaxis_title="Time",
yaxis_title="Speed",
legend_title="Subject",
plot_bgcolor="white",
)
# pio.write_image(
# fig,
# str(save_dir / "speed_over_time.svg"),
# format="svg"
# )
fig.show()

3. Foraging bouts raster plot#
dark_color = "#555555"
subjects = sorted(social_foraging_df["subject"].unique())
fig = px.timeline(
social_foraging_df,
x_start="start",
x_end="end",
y="subject",
hover_data=["n_pellets", "cum_wheel_dist"],
category_orders={"subject": subjects},
)
fig.update_traces(
opacity=1,
marker_color=dark_color,
marker_line_color=dark_color,
marker_line_width=1.5,
)
fig.update_layout(
template="simple_white",
plot_bgcolor="white",
margin=dict(l=150, r=20, t=20, b=20),
height=max(100, len(subjects) * 25 + 50),
xaxis=dict(
showgrid=False, zeroline=False, showline=False, ticks="", showticklabels=False
),
yaxis=dict(
showgrid=False,
zeroline=False,
showline=False,
ticks="",
showticklabels=True,
title="",
),
)
# pio.write_image(
# fig,
# str(save_dir / "foraging_bouts_raster.svg"),
# format="svg"
# )
fig.show()

4. Wheel distance spun#
Over time per subject and patch#
dt_seconds = 0.02
fig = go.Figure()
# Build each trace (continuous + Δ>0.5 downsample)
max_y = 0
for (subject, patch), grp in social_patch_df.groupby(["subject_name", "patch_name"]):
grp = grp.sort_values("block_start")
total_n = sum(len(a) for a in grp.wheel_cumsum_distance_travelled)
times = np.empty(total_n, dtype="datetime64[ns]")
dists = np.empty(total_n, dtype=float)
idx, offset = 0, 0.0
for bs_val, arr in zip(grp.block_start, grp.wheel_cumsum_distance_travelled):
arr = np.asarray(arr)
n = arr.size
offs = (np.arange(n) * dt_seconds * 1e9).astype("timedelta64[ns]")
times[idx : idx + n] = np.datetime64(bs_val) + offs
dists[idx : idx + n] = arr + offset
offset += arr[-1]
idx += n
max_y = max(max_y, dists.max())
# Downsample
diffs = np.abs(np.diff(dists, prepend=dists[0]))
mask = diffs > 0.5
mask[0] = True
fig.add_trace(
go.Scatter(
x=times[mask],
y=dists[mask] / 100, # convert to meters
mode="lines",
name=f"{subject} — {patch}",
line=dict(width=1.5),
)
)
# Hide all ticks and tick lines; keep only the y-axis title
fig.update_xaxes(
showgrid=False, zeroline=False, showline=False, showticklabels=False, ticks=""
)
fig.update_yaxes(
title_text="Distance spun on wheel (m)",
showgrid=False,
zeroline=False,
showline=False,
showticklabels=False,
ticks="",
)
# Vertical scale bar: 250 m high at the right edge
fig.add_shape(
type="line",
xref="paper",
x0=1.02,
x1=1.02,
yref="y",
y0=0,
y1=250,
line=dict(color="black", width=2),
)
fig.add_annotation(
xref="paper",
x=1.04,
y=125,
text="250",
showarrow=False,
xanchor="left",
yanchor="middle",
font=dict(size=10, color="black"),
)
# Layout tweaks and show
fig.update_layout(
template="simple_white", margin=dict(l=20, r=80, t=20, b=20), showlegend=True
)
# pio.write_image(
# fig,
# str(save_dir / "wheel_dist_over_time.svg"),
# format="svg"
# )
fig.show(
config={
"staticPlot": True,
}
)

Dummy patch vs normal patches#
patch_dfs = []
for exp in [experiments[i] for i in [4, 5]]:
data = load_experiment_data(experiment=exp, data_dir=data_dir, data_types=["patch"])
df = data["None_patch"]
patch_dfs.append(df)
patch_df_s4 = pd.concat(patch_dfs).sort_index()
dt_seconds = 0.02
summary = []
# Compute total distance spun per (subject, patch)
for (subject, patch), grp in patch_df_s4.groupby(["subject_name", "patch_name"]):
total_distance = sum(
np.asarray(w)[-1] for w in grp["wheel_cumsum_distance_travelled"] if len(w) > 0
)
summary.append(
{"subject": subject, "patch": patch, "distance": total_distance / 100}
) # convert to meters
summary_df = pd.DataFrame(summary)
# Pivot to subject x patch format
pivot_df = summary_df.pivot(index="subject", columns="patch", values="distance").fillna(
0
)
# Per-subject values for PatchDummy1 and mean of other patches
patch_dummy1_vals = pivot_df.get("PatchDummy1", pd.Series(0, index=pivot_df.index))
other_patch_vals = pivot_df.drop(columns="PatchDummy1", errors="ignore").mean(axis=1)
# Compute means and SEMs
means = [patch_dummy1_vals.mean(), other_patch_vals.mean()]
sems = [patch_dummy1_vals.sem(), other_patch_vals.sem()]
labels = ["PatchDummy1", "Avg Other Patches"]
# Plot
fig = go.Figure(
data=[
go.Bar(
x=labels,
y=means,
error_y=dict(type="data", array=sems, visible=True),
marker_color=["#636EFA", "#EF553B"],
)
]
)
fig.update_layout(
title="Mean Wheel Distance Spun ± SEM",
yaxis_title="Distance (m, log)",
xaxis_title="Patch",
yaxis_type="log",
template="simple_white",
height=400,
)
# pio.write_image(
# fig,
# str(save_dir / "dummy_vs_normal_patches.svg"),
# format="svg"
# )
fig.show()

display(patch_dummy1_vals, other_patch_vals)
t_stat, p_val = ttest_rel(patch_dummy1_vals, other_patch_vals)
print(f"Paired t‐statistic = {t_stat:.3f}, p‐value = {p_val:.4f}")
subject
BAA-1104792 25.288925
BAA-1104794 20.960074
BAA-1104795 24.882764
BAA-1104797 6.401984
Name: PatchDummy1, dtype: float64
subject
BAA-1104792 4300.985857
BAA-1104794 3847.328388
BAA-1104795 3551.048284
BAA-1104797 3306.791160
dtype: float64
Paired t‐statistic = -17.706, p‐value = 0.0004
5. Pellets raster plot#
df = (
social_patch_df[["subject_name", "patch_name", "pellet_timestamps"]]
.explode("pellet_timestamps")
.dropna(subset=["pellet_timestamps"])
)
df["pellet_timestamps"] = pd.to_datetime(df["pellet_timestamps"])
df["patch_idx"] = df["patch_name"].str.extract(r"(\d+)$").astype(int)
dark_color = "#555555"
subjects = sorted(df["subject_name"].unique())
n_subj = len(subjects)
fig = make_subplots(rows=n_subj, cols=1, shared_xaxes=True, subplot_titles=subjects)
for i, subj in enumerate(subjects, start=1):
sub = df[df["subject_name"] == subj]
fig.add_trace(
go.Scatter(
x=sub["pellet_timestamps"],
y=sub["patch_idx"],
mode="markers",
marker=dict(symbol="line-ns", color=dark_color, size=8, line_width=1.2),
showlegend=False,
),
row=i,
col=1,
)
# Y-axis = patch numbers
fig.update_yaxes(
row=i,
col=1,
title="Patches",
tickmode="array",
showgrid=False,
zeroline=False,
showline=False,
ticks="",
)
# Remove x-labels on every row
fig.update_xaxes(
row=i,
col=1,
showticklabels=False,
showgrid=False,
zeroline=False,
showline=False,
ticks="",
)
# Tighten the margins and overall height
fig.update_layout(
template="simple_white",
plot_bgcolor="white",
margin=dict(l=80, r=20, t=80, b=20),
height=120 * n_subj,
)
# Center subject titles
for ann in fig.layout.annotations:
ann.x = 0.5
ann.xanchor = "center"
ann.font = dict(size=16)
# pio.write_image(
# fig,
# str(save_dir / "pellets_raster.svg"),
# format="svg"
# )
fig.show()

6. Weight over time#
Per subject for a single experiment#
# Keep only rows that have timestamps
df = social_weight_df[
social_weight_df["timestamps"].apply(lambda lst: len(lst) > 0)
].copy()
# Explode parallel list-columns into long form
df = df.explode(["timestamps", "weight", "subject_id"])
# Type conversions
df["timestamps"] = pd.to_datetime(df["timestamps"])
df["weight"] = df["weight"].astype(float)
# Plot
fig = px.line(
df,
x="timestamps",
y="weight",
color="subject_id",
markers=True,
title="Subject Weights Over Time",
labels={"timestamps": "Time", "weight": "Weight", "subject_id": "Subject ID"},
)
fig.update_layout(legend_title_text="Subject ID")
fig.show()

Averaged over time#
social_weight_dfs = []
for exp in [experiments[i] for i in [0, 1, 4, 5]]:
data = load_experiment_data(
experiment=exp, data_dir=data_dir, periods=["social"], data_types=["weight"]
)
df = data["social_weight"]
social_weight_dfs.append(df)
social_weight_df_all_exps = pd.concat(social_weight_dfs).sort_index()
# Flatten the “array-of-arrays” df into a long DataFrame
records = []
for _, row in social_weight_df_all_exps.iterrows():
ts = pd.to_datetime(row["timestamps"])
w = np.asarray(row["weight"], dtype=float)
sids = row["subject_id"]
for t, weight, sid in zip(ts, w, sids):
records.append({"timestamp": t, "subject_id": sid, "weight": weight})
df = pd.DataFrame.from_records(records)
df = df.sort_values(["subject_id", "timestamp"])
# Drop rows where subject_id is shorter than 11 characters, removes some erroneous entries
df = df[df["subject_id"].str.len() >= 11].reset_index(drop=True)
# Parameters & 24 h cycle grid
sampling_freq = "10min" # resample interval
# Choose any date at 08:00 to define “day zero”
anchor = pd.Timestamp(f"2020-01-01 {light_off:02d}:00:00")
# Build the 24 h cycle index
n_steps = int(pd.Timedelta("1D") / pd.Timedelta(sampling_freq))
cycle_index = pd.timedelta_range(start=0, periods=n_steps, freq=sampling_freq)
# Will hold each mouse’s mean‐day
cycle_df = pd.DataFrame(index=cycle_index)
# Loop over each mouse, resample, fold into 24 h, average days
for sid, grp in df.groupby("subject_id"):
# Series of weight vs time
ser = grp.set_index("timestamp")["weight"].sort_index()
# Collapse any exact-duplicate timestamps
ser = ser.groupby(level=0).mean()
# Resample into bins anchored at 08:00, then interpolate
ser_rs = ser.resample(sampling_freq, origin=anchor).mean().interpolate()
# Convert each timestamp into its offset (mod 24 h) from the anchor
offsets = (ser_rs.index - anchor) % pd.Timedelta("1D")
ser_rs.index = offsets
# Average across all days for each offset
daily = ser_rs.groupby(ser_rs.index).mean()
# Align to uniform cycle grid
cycle_df[sid] = daily.reindex(cycle_index)
# Baseline‐subtract each mouse’s minimum, then grand‐mean
cycle_df_baselined = cycle_df.subtract(cycle_df.min(skipna=True), axis=1)
grand_mean = cycle_df_baselined.mean(axis=1)
sem = cycle_df_baselined.sem(axis=1)
# Smooth both mean and SEM with a centered rolling window
window = 20
grand_mean_smooth = grand_mean.rolling(window=window, center=True, min_periods=1).mean()
sem_smooth = sem.rolling(window=window, center=True, min_periods=1).mean()
# Plot each subject's mean-day curve in its own subplot
n_subj = cycle_df_baselined.shape[1]
n_cols = 4 # adjust as needed
n_rows = math.ceil(n_subj / n_cols)
fig, axes = plt.subplots(
n_rows, n_cols, figsize=(3.5 * n_cols, 2.5 * n_rows), sharex=True, sharey=True
)
axes = axes.flatten()
x_hours = cycle_df_baselined.index.total_seconds() / 3600
for i, sid in enumerate(cycle_df_baselined.columns):
ax = axes[i]
y = cycle_df_baselined[sid]
y_smooth = y.rolling(window=window, center=True, min_periods=1).mean()
ax.plot(x_hours, y_smooth, lw=1.5)
ax.set_title(sid, fontsize=9)
ax.set_xlim(0, 24)
ax.grid(True, linestyle=":", linewidth=0.5)
# Remove any unused axes
for ax in axes[n_subj:]:
ax.set_visible(False)
fig.suptitle("24 h weight cycle (baseline-subtracted) per subject", y=1.02)
fig.text(0.5, 0.04, f"Time since {light_off:02d}:00:00 (hours)", ha="center")
fig.text(0.04, 0.5, "Weight (baseline-subtracted)", va="center", rotation="vertical")
plt.subplots_adjust(hspace=0.4, wspace=0.3, bottom=0.1, left=0.07, right=0.97, top=0.90)
plt.show()

Averaged over time and subjects#
Baselined by subtracting the minimum of each subject’s 24h mean-day curve
# Flatten the “array-of-arrays” df into a long DataFrame
records = []
for _, row in social_weight_df_all_exps.iterrows():
ts = pd.to_datetime(row["timestamps"])
w = np.asarray(row["weight"], dtype=float)
sids = row["subject_id"]
for t, weight, sid in zip(ts, w, sids):
records.append({"timestamp": t, "subject_id": sid, "weight": weight})
df = pd.DataFrame.from_records(records)
df = df.sort_values(["subject_id", "timestamp"])
# Drop rows where subject_id is shorter than 11 characters, removes some erroneous entries
df = df[df["subject_id"].str.len() >= 11].reset_index(drop=True)
# Parameters & 24 h cycle grid
sampling_freq = "10min" # resample interval
# Choose any date at 08:00 to define “day zero”
anchor = pd.Timestamp(f"2020-01-01 {light_off:02d}:00:00")
# Build the 24 h cycle index
n_steps = int(pd.Timedelta("1D") / pd.Timedelta(sampling_freq))
cycle_index = pd.timedelta_range(start=0, periods=n_steps, freq=sampling_freq)
# Will hold each mouse’s mean‐day
cycle_df = pd.DataFrame(index=cycle_index)
# Loop over each mouse, resample, fold into 24 h, average days
for sid, grp in df.groupby("subject_id"):
# Series of weight vs time
ser = grp.set_index("timestamp")["weight"].sort_index()
# Collapse any exact-duplicate timestamps
ser = ser.groupby(level=0).mean()
# Resample into bins anchored at 08:00, then interpolate
ser_rs = ser.resample(sampling_freq, origin=anchor).mean().interpolate()
# Convert each timestamp into its offset (mod 24 h) from the anchor
offsets = (ser_rs.index - anchor) % pd.Timedelta("1D")
ser_rs.index = offsets
# Average across all days for each offset
daily = ser_rs.groupby(ser_rs.index).mean()
# Align to uniform cycle grid
cycle_df[sid] = daily.reindex(cycle_index)
# Baseline‐subtract each mouse’s minimum, then grand‐mean
cycle_df_baselined = cycle_df.subtract(cycle_df.min())
grand_mean = cycle_df_baselined.mean(axis=1)
sem = cycle_df_baselined.sem(axis=1)
# Smooth both mean and SEM with a centered rolling window
window = 20
grand_mean_smooth = grand_mean.rolling(window=window, center=True, min_periods=1).mean()
sem_smooth = sem.rolling(window=window, center=True, min_periods=1).mean()
# Plot mean ± SEM
plt.figure(figsize=(10, 4))
x_hours = cycle_df_baselined.index.total_seconds() / 3600
plt.plot(x_hours, grand_mean_smooth, lw=2, label="Mean weight")
plt.fill_between(
x_hours,
grand_mean_smooth - sem_smooth,
grand_mean_smooth + sem_smooth,
alpha=0.3,
label="± SEM",
)
plt.xlabel(f"Time since {light_off:02d}:00:00 (hours)")
plt.ylabel("Weight (baseline-subtracted)")
plt.title("Average 24 h weight cycle across all mice\n(Mean ± SEM)")
plt.xlim(0, 24)
plt.legend()
plt.tight_layout()
plt.show()

Baselined by subtracting the minimum of each subject’s smoothed 24h mean-day curve
# Flatten the “array-of-arrays” df into a long DataFrame
records = []
for _, row in social_weight_df_all_exps.iterrows():
ts = pd.to_datetime(row["timestamps"])
w = np.asarray(row["weight"], dtype=float)
sids = row["subject_id"]
for t, weight, sid in zip(ts, w, sids):
records.append({"timestamp": t, "subject_id": sid, "weight": weight})
df = pd.DataFrame.from_records(records)
df = df.sort_values(["subject_id", "timestamp"])
# Drop rows where subject_id is shorter than 11 characters, removes some erroneous entries
df = df[df["subject_id"].str.len() >= 11].reset_index(drop=True)
# Parameters & 24 h cycle grid
sampling_freq = "10min" # resample interval
# Choose any date at 08:00 to define “day zero”
anchor = pd.Timestamp(f"2020-01-01 {light_off:02d}:00:00")
# Build the 24 h cycle index
n_steps = int(pd.Timedelta("1D") / pd.Timedelta(sampling_freq))
cycle_index = pd.timedelta_range(start=0, periods=n_steps, freq=sampling_freq)
# Will hold each mouse’s mean‐day (unsmoothed)
cycle_df = pd.DataFrame(index=cycle_index)
# Loop over each mouse, resample, fold into 24 h, average days
for sid, grp in df.groupby("subject_id"):
# Series of weight vs time
ser = grp.set_index("timestamp")["weight"].sort_index()
# Collapse any exact-duplicate timestamps
ser = ser.groupby(level=0).mean()
# Resample into bins anchored at 08:00, then interpolate
ser_rs = ser.resample(sampling_freq, origin=anchor).mean().interpolate()
# Convert each timestamp into its offset (mod 24 h) from the anchor
offsets = (ser_rs.index - anchor) % pd.Timedelta("1D")
ser_rs.index = offsets
# Average across all days for each offset
daily = ser_rs.groupby(ser_rs.index).mean()
# Align to uniform cycle grid
cycle_df[sid] = daily.reindex(cycle_index)
# Baseline‐subtract using each subject’s smoothed minimum
window = 20 # keep same smoothing window for per‐subject curves
# Smooth each column (i.e., each subject’s 24 h curve)
cycle_df_smooth = cycle_df.rolling(window=window, center=True, min_periods=1).mean()
# Find the minimum of each subject’s smoothed curve
minima_smooth = cycle_df_smooth.min() # series indexed by subject_id
# Subtract that baseline from the UNSMOOTHED cycle for each subject
cycle_df_baselined = cycle_df.subtract(minima_smooth, axis=1)
# Now compute grand‐mean and SEM (over baselined, unsmoothed curves)
grand_mean = cycle_df_baselined.mean(axis=1)
sem = cycle_df_baselined.sem(axis=1)
# Smooth both grand‐mean and SEM for plotting
grand_mean_smooth = grand_mean.rolling(window=window, center=True, min_periods=1).mean()
sem_smooth = sem.rolling(window=window, center=True, min_periods=1).mean()
# Plot mean ± SEM
plt.figure(figsize=(10, 4))
x_hours = cycle_df_baselined.index.total_seconds() / 3600
plt.plot(x_hours, grand_mean_smooth, lw=2, label="Mean weight (baselined)")
plt.fill_between(
x_hours,
grand_mean_smooth - sem_smooth,
grand_mean_smooth + sem_smooth,
alpha=0.3,
label="± SEM",
)
plt.xlabel(f"Time since {light_off:02d}:00:00 (hours)")
plt.ylabel("Weight (baseline‐subtracted)")
plt.title("Average 24 h weight cycle across all mice\n(Mean ± SEM)")
plt.xlim(0, 24)
plt.legend()
plt.tight_layout()
# svg_path = save_dir / "24h_weight_cycle.svg"
# plt.savefig(str(svg_path), format='svg')
plt.show()
