Platform paper social experiment analysis: Part 2#
In this example we will work with behavioural data collected from experiments social0.2, social0.3, and social0.4, in which two mice foraged for food in the habitat with three foraging patches whose reward rates changed dynamically over time.
The experiments each consist of three periods:
“presocial”, in which each mouse was in the habitat alone for 3-4 days.
“social”, in which both mice were in the habitat together for 2 weeks.
“postsocial”, in which each mouse was in the habitat alone again for 3-4 days.
The goal of the experiments was to understand how the mice’s behaviour changes as they learn to forage for food in the habitat, and how it differs between social and solo settings.
The full datasets are available on the Datasets page, but for the purposes of this example we will use the precomputed Platform paper social analysis datasets.
See also
“Extended Data Fig. 7” in the “Extended Data” section of the platform paper’s “Supplementary Material” gives a detailed description of the experiments.
Below is a brief explanation of how the environment (i.e. patch properties) changed over blocks (60–180 minute periods of time):
Every block lasts a random duration \(t\):

\[
t \sim \mathrm{Uniform}(60,\,180) \quad \text{In minutes}
\]

At the start of each block, sample a row from the predefined matrix \(\lambda_{\mathrm{set}}\):

\[
\lambda_{\mathrm{set}} =
\begin{pmatrix}
1 & 1 & 1 \\
5 & 5 & 5 \\
1 & 3 & 5 \\
1 & 5 & 3 \\
3 & 1 & 5 \\
3 & 5 & 1 \\
5 & 1 & 3 \\
5 & 3 & 1
\end{pmatrix}
\quad \text{In meters}
\]

Assign the sampled row to the patch means \(\lambda_{1}, \lambda_{2}, \lambda_{3}\) and apply a constant offset \(c\) to all thresholds:

\[
\begin{aligned}
\lambda_{1}, \lambda_{2}, \lambda_{3} &\sim \mathrm{Uniform}(\lambda_{\mathrm{set}}) \\
c &= 0.75
\end{aligned}
\quad \text{Patch means and offset}
\]

Sample a value from each of \(P_{1}, P_{2}, P_{3}\) as the initial threshold for the respective patch. Whenever a patch reaches its threshold, resample a new value from its corresponding distribution:

\[
\begin{aligned}
P_{1} &= c + \mathrm{Exp}(1/\lambda_{1}) \\
P_{2} &= c + \mathrm{Exp}(1/\lambda_{2}) \\
P_{3} &= c + \mathrm{Exp}(1/\lambda_{3})
\end{aligned}
\quad \text{Patch distributions}
\]
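To make this sampling scheme concrete, here is a minimal sketch of one block's environment draw. This is an illustration only, not code from the platform paper; the sample_block helper and the use of NumPy's default generator are our own assumptions.

import numpy as np

rng = np.random.default_rng(seed=0)

# Predefined set of patch means, in meters (the rows of lambda_set above)
lambda_set = np.array([
    [1, 1, 1], [5, 5, 5], [1, 3, 5], [1, 5, 3],
    [3, 1, 5], [3, 5, 1], [5, 1, 3], [5, 3, 1],
])
c = 0.75  # constant offset added to every threshold


def sample_block(rng: np.random.Generator) -> tuple[float, np.ndarray, np.ndarray]:
    """Draw one block: its duration, patch means, and initial patch thresholds."""
    t = rng.uniform(60, 180)  # block duration in minutes
    lam = lambda_set[rng.integers(len(lambda_set))]  # one row -> (λ1, λ2, λ3)
    # P_i = c + Exp(1/λ_i): exponential with rate 1/λ_i, i.e. mean λ_i
    thresholds = c + rng.exponential(scale=lam)
    return t, lam, thresholds


t, lam, thresholds = sample_block(rng)
# Whenever a patch reaches its threshold, resample just that patch:
patch = 0
thresholds[patch] = c + rng.exponential(scale=lam[patch])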
Set up environment#
Create and activate a virtual environment named aeon-social-analysis using uv.
uv venv aeon-social-analysis --python ">=3.11"
source aeon-social-analysis/bin/activate # Unix
.\aeon-social-analysis\Scripts\activate # Windows
Install the required packages and their dependencies.
uv pip install matplotlib numpy pandas plotly seaborn statsmodels pyyaml pyarrow tqdm scipy jupyter
Import libraries and define variables and helper functions#
"""Notebook settings and imports"""
from pathlib import Path
from warnings import warn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import seaborn as sns
from scipy import stats
from scipy.ndimage import uniform_filter1d
from tqdm.auto import tqdm
# Plot settings
sns.set_style("whitegrid")
plt.rcParams["axes.titlesize"] = 20
plt.rcParams["axes.labelsize"] = 18
plt.rcParams["xtick.labelsize"] = 15
plt.rcParams["ytick.labelsize"] = 15
plt.rcParams["legend.title_fontsize"] = 15
plt.rcParams["legend.fontsize"] = 14
# Constants
cm2px = 5.2 # 1 cm = 5.2 px roughly in aeon arenas
light_off, light_on = 7, 20  # lights off at 7:00 (7am), lights on at 20:00 (8pm)
subject_colors = plotly.colors.qualitative.Dark24
patch_colors = plotly.colors.qualitative.Light24
patch_markers = [
"circle",
"bowtie",
"square",
"hourglass",
"diamond",
"cross",
"x",
"triangle",
"star",
]
patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"]
patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols, strict=False))
patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"]
subject_markers_linestyles = patch_markers_linestyles.copy()
patch_type_mean_map = {100: "l", 300: "m", 500: "h", 200: "l", 600: "m", 1000: "h"}
patch_type_rate_map = {
0.01: "l",
0.0033: "m",
0.002: "h",
0.005: "l",
0.00167: "m",
0.001: "h",
}
experiments = [
{
"name": "social0.2-aeon3",
"presocial_start": "2024-01-31 11:00:00",
"presocial_end": "2024-02-08 15:00:00",
"social_start": "2024-02-09 16:00:00",
"social_end": "2024-02-23 13:00:00",
"postsocial_start": "2024-02-25 17:00:00",
"postsocial_end": "2024-03-02 14:00:00",
},
{
"name": "social0.2-aeon4",
"presocial_start": "2024-01-31 11:00:00",
"presocial_end": "2024-02-08 15:00:00",
"social_start": "2024-02-09 17:00:00",
"social_end": "2024-02-23 12:00:00",
"postsocial_start": "2024-02-25 18:00:00",
"postsocial_end": "2024-03-02 13:00:00",
},
{
"name": "social0.3-aeon3",
"presocial_start": "2024-06-08 19:00:00",
"presocial_end": "2024-06-17 13:00:00",
"social_start": "2024-06-25 11:00:00",
"social_end": "2024-07-06 13:00:00",
"postsocial_start": "2024-07-07 16:00:00",
"postsocial_end": "2024-07-14 14:00:00",
},
{
"name": "social0.3-aeon4",
},
{
"name": "social0.4-aeon3",
"presocial_start": "2024-08-16 17:00:00",
"presocial_end": "2024-08-24 10:00:00",
"social_start": "2024-08-28 11:00:00",
"social_end": "2024-09-09 13:00:00",
"postsocial_start": "2024-09-09 18:00:00",
"postsocial_end": "2024-09-22 16:00:00",
},
{
"name": "social0.4-aeon4",
"presocial_start": "2024-08-16 15:00:00",
"presocial_end": "2024-08-24 10:00:00",
"social_start": "2024-08-28 10:00:00",
"social_end": "2024-09-09 01:00:00",
"postsocial_start": "2024-09-09 15:00:00",
"postsocial_end": "2024-09-22 16:00:00",
},
]
periods = ["social", "postsocial"]
# Define the possible combos of social and light
combos = [
(True, True), # Social + Light
(True, False), # Social + Dark
(False, True), # Solo + Light
(False, False), # Solo + Dark
]
# Define colors based on light condition (light=blue, dark=orange)
colors = {
True: "#1f77b4", # Blue for light conditions
False: "#ff7f0e", # Orange for dark conditions
}
# Define hatch patterns based on social condition
hatches = {
True: "///", # Hatched pattern for social
False: None, # No pattern (solid) for solo
}
labels = ["Social-Light", "Social-Dark", "Solo-Light", "Solo-Dark"]
def load_data_from_parquet(
experiment_name: str | None,
period: str | None,
data_type: str,
data_dir: Path,
set_time_index: bool = False,
) -> pd.DataFrame:
"""Loads saved data from parquet files.
Args:
experiment_name (str, optional): Filter by experiment name. If None, load all experiments.
period (str, optional): Filter by period (presocial, social, postsocial). If None, load all periods.
data_type (str): Type of data to load (position, patch, foraging, rfid, sleep, explore)
data_dir (Path): Directory containing parquet files.
set_time_index (bool, optional): If True, set 'time' column as DataFrame index.
Returns:
pd.DataFrame: Combined DataFrame of all matching parquet files.
"""
if not data_dir.exists():
print(f"Directory {data_dir} does not exist. No data files found.")
return pd.DataFrame()
# Create pattern based on filters
pattern = ""
if experiment_name:
pattern += f"{experiment_name}_"
else:
pattern += "*_"
if period:
pattern += f"{period}_"
else:
pattern += "*_"
pattern += f"{data_type}.parquet"
# Find matching files
matching_files = list(data_dir.glob(pattern))
if not matching_files:
print(f"No matching data files found with pattern: {pattern}")
return pd.DataFrame()
print(f"Found {len(matching_files)} matching files")
# Load and concatenate matching files
dfs = []
total_rows = 0
for file in matching_files:
print(f"Loading {file}...")
df = pd.read_parquet(file)
total_rows += len(df)
dfs.append(df)
print(f" Loaded {len(df)} rows")
# Combine data
if dfs:
combined_df = pd.concat(dfs, ignore_index=True)
if set_time_index and "time" in combined_df.columns:
combined_df = combined_df.set_index("time")
print(f"Combined data: {len(combined_df)} rows")
return combined_df
else:
return pd.DataFrame()
def load_experiment_data(
data_dir: Path,
experiment: dict | None = None,
periods: list | None = None,
data_types: list[str] = ["rfid", "position"],
trim_days: int | None = None,
) -> dict:
"""Load all data types for specified periods of an experiment.
Parameters:
- experiment: experiment dict with period start/end times
- periods: list of periods to load
- data_types: list of data types to load
- data_dir: directory containing data files
- trim_days: Optional number of days to trim from start (None = no trim)
Returns:
- Dictionary containing dataframes for each period/data type combination
"""
result = {}
if periods is None:
periods = [None]
for period in periods:
for data_type in data_types:
print(f"Loading {period} {data_type} data...")
# Load data
experiment_name = experiment["name"] if experiment is not None else None
df = load_data_from_parquet(
experiment_name=experiment_name,
period=period,
data_type=data_type,
data_dir=data_dir,
set_time_index=(data_type == "position"),
)
# Trim if requested
if trim_days is not None and len(df) > 0:
if data_type == "rfid":
start_time = df["chunk_start"].min()
end_time = start_time + pd.Timedelta(days=trim_days)
df = df[df["chunk_start"] < end_time]
if data_type == "foraging":
start_time = df["start"].min()
end_time = start_time + pd.Timedelta(days=trim_days)
df = df[df["start"] < end_time]
if data_type == "position":
start_time = df.index.min()
end_time = start_time + pd.Timedelta(days=trim_days)
df = df.loc[df.index < end_time]
print(f" Trimmed to {trim_days} days: {len(df)} records")
# Store in result
key = f"{period}_{data_type}"
result[key] = df
# For position data, handle duplicates
if data_type == "position" and len(df) > 0:
original_len = len(df)
df = df.reset_index()
df = df.drop_duplicates(subset=["time", "identity_name"])
df = df.set_index("time")
result[key] = df
if len(df) < original_len:
print(f" Removed duplicates: {original_len} -> {len(df)}")
return result
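As a quick illustration, these helpers compose as follows; the path below is a placeholder (see the note that follows), and the chosen experiment, period, and data types are just examples.

# Hypothetical usage: load social-period RFID and position data for one experiment
data = load_experiment_data(
    data_dir=Path("/path/to/parquet"),  # placeholder; set to your local dataset
    experiment=experiments[0],          # "social0.2-aeon3"
    periods=["social"],
    data_types=["rfid", "position"],
)
social_pos_df = data["social_position"]  # keys are "{period}_{data_type}"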
Note
Change `data_dir` and `save_dir` to the paths where your local dataset (the parquet files) is stored and where you want to save the results.
# SET THESE VARIABLES ACCORDINGLY
data_dir = Path("")
save_dir = Path("")
Solo vs. Social Behaviours#
Exploring#
Distance travelled#
# Final df:
# rows = hour-datetime,
# columns = distance, exp, social-bool, subject, light-bool
dist_trav_hour_df = pd.DataFrame(
columns=["hour", "distance", "exp", "social", "subject", "light"]
)
# For each period
# Load pos data
# Split into individual dfs
# If social, excise swaps
# Smooth down to 1s
# Calculate hour-by-hour distance traveled, and put into final df
exp_pbar = tqdm(experiments, desc="Experiments", position=0, leave=True)
for exp in exp_pbar:
period_pbar = tqdm(periods, desc="Periods", position=1, leave=True)
for period in period_pbar:
pos_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="position",
data_dir=data_dir,
set_time_index=True,
)
for subject in pos_df["identity_name"].unique():
pos_df_subj = pos_df[pos_df["identity_name"] == subject]
pos_df_subj = pos_df_subj.resample("200ms").first().dropna(subset=["x"])
pos_df_subj[["x", "y"]] = pos_df_subj[["x", "y"]].rolling("1s").mean()
pos_df_subj = pos_df_subj.resample("1s").first().dropna(subset=["x"])
pos_df_subj["distance"] = np.sqrt(
(pos_df_subj["x"].diff() ** 2) + (pos_df_subj["y"].diff() ** 2)
)
pos_df_subj.at[pos_df_subj.index[0], "distance"] = 0
pos_df_subj["distance"] /= cm2px * 100 # convert to m
pos_df_subj["hour"] = pos_df_subj.index.floor("h")
pos_df_subj_hour = (
pos_df_subj.groupby("hour")["distance"].sum().reset_index()
)
pos_df_subj_hour["exp"] = exp["name"]
pos_df_subj_hour["social"] = period == "social"
pos_df_subj_hour["subject"] = subject
hour = pos_df_subj_hour["hour"].dt.hour
pos_df_subj_hour["light"] = ~((hour > light_off) & (hour < light_on))
dist_trav_hour_df = pd.concat(
[dist_trav_hour_df, pos_df_subj_hour], ignore_index=True
)
# # Save as parquet
# dist_trav_hour_df.to_parquet(
# data_dir / "for_plots" / "dist_trav_hour_df.parquet",
# engine="pyarrow",
# compression="snappy",
# index=False,
# )
# Load the parquet file
dist_trav_hour_df = pd.read_parquet(
data_dir / "for_plots" / "dist_trav_hour_df.parquet",
engine="pyarrow",
)
display(dist_trav_hour_df)
| | hour | distance | exp | social | subject | light |
|---|---|---|---|---|---|---|
| 0 | 2024-01-31 11:00:00 | 214.375787 | social0.2-aeon3 | False | BAA-1104045 | False |
| 1 | 2024-01-31 12:00:00 | 358.672416 | social0.2-aeon3 | False | BAA-1104045 | False |
| 2 | 2024-01-31 13:00:00 | 301.952548 | social0.2-aeon3 | False | BAA-1104045 | False |
| 3 | 2024-01-31 14:00:00 | 284.154738 | social0.2-aeon3 | False | BAA-1104045 | False |
| 4 | 2024-01-31 15:00:00 | 420.268372 | social0.2-aeon3 | False | BAA-1104045 | False |
| ... | ... | ... | ... | ... | ... | ... |
| 4840 | 2024-09-22 13:00:00 | 263.921865 | social0.4-aeon4 | False | BAA-1104797 | False |
| 4841 | 2024-09-22 14:00:00 | 316.511526 | social0.4-aeon4 | False | BAA-1104797 | False |
| 4842 | 2024-09-22 15:00:00 | 281.001766 | social0.4-aeon4 | False | BAA-1104797 | False |
| 4843 | 2024-09-22 16:00:00 | 171.733688 | social0.4-aeon4 | False | BAA-1104797 | False |
| 4844 | 2024-09-22 17:00:00 | 0.000905 | social0.4-aeon4 | False | BAA-1104797 | False |

4845 rows × 6 columns
"""Hists."""
fig, ax = plt.subplots(figsize=(14, 8))
# Plot histograms for each combination
for i, (social_val, light_val) in enumerate(combos):
# Filter data for this combination
subset = dist_trav_hour_df[
(dist_trav_hour_df["social"] == social_val)
& (dist_trav_hour_df["light"] == light_val)
]
# Plot normalized histogram
hist = sns.histplot(
data=subset,
x="distance",
stat="probability", # This normalizes the histogram
alpha=0.5,
color=colors[light_val],
label=labels[i],
# kde=True, # Add kernel density estimate
common_norm=False, # Ensure each histogram is normalized separately
        ax=ax,
binwidth=20,
)
# Set hatch pattern for bars
if hatches[social_val]:
for bar in hist.patches:
bar.set_hatch(hatches[social_val])
ax.set_title(
"Normalized Distance Traveled Distributions by Social and Light Conditions"
)
ax.set_xlabel("Distance Traveled (m / h)")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
ax.set_ylim(0, 0.2)
(0.0, 0.2)

"""Bars."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in [True, False]:
for light_val in [True, False]:
subset = dist_trav_hour_df[
(dist_trav_hour_df["social"] == social_val)
& (dist_trav_hour_df["light"] == light_val)
]
mean_dist = subset["distance"].mean()
sem_dist = subset["distance"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_distance": mean_dist,
"sem": sem_dist,
"condition": (
f"{'Social' if social_val else 'Solo'}-"
f"{'Light' if light_val else 'Dark'}",
),
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_distance"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_distance']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val]:
bar[0].set_hatch(hatches[social_val])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_distance"] + row["sem"] + 5,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_ylabel("Mean Distance Traveled (m / h)")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.set_title("Mean Distance Traveled by Social and Light Conditions")
ax.legend(title="Conditions", loc="upper left")
ax.xaxis.grid(False)
# Add stats tests
light_social = dist_trav_hour_df[
(dist_trav_hour_df["social"] == True) & (dist_trav_hour_df["light"] == True)
]["distance"]
light_solo = dist_trav_hour_df[
(dist_trav_hour_df["social"] == False) & (dist_trav_hour_df["light"] == True)
]["distance"]
dark_social = dist_trav_hour_df[
(dist_trav_hour_df["social"] == True) & (dist_trav_hour_df["light"] == False)
]["distance"]
dark_solo = dist_trav_hour_df[
(dist_trav_hour_df["social"] == False) & (dist_trav_hour_df["light"] == False)
]["distance"]
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.02,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=65.89, sem=2.22, n=1666
Plotting Social-Dark: mean=283.67, sem=3.74, n=1406
Plotting Solo-Light: mean=22.04, sem=1.20, n=949
Plotting Solo-Dark: mean=135.10, sem=3.63, n=824
Text(0.02, 0.68, 'Two-sample t-tests:\nLight conditions: p = 9.89e-64\nDark conditions: p = 2.04e-151')

Bouts#
# Final df:
# rows = hour-datetime,
# columns = n_bouts, exp, social-bool, subject, light-bool
explore_hour_df = pd.DataFrame(
columns=["hour", "n_bouts", "exp", "social", "subject", "light"]
)
explore_dur_df = pd.DataFrame(
columns=["start", "duration", "exp", "social", "subject", "light"]
)
exp_pbar = tqdm(experiments, desc="Experiments", position=0, leave=True)
for exp in exp_pbar:
period_pbar = tqdm(periods, desc="Periods", position=1, leave=False)
for period in period_pbar:
explore_bouts_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="explore",
data_dir=data_dir,
set_time_index=True,
)
for subject in explore_bouts_df["subject"].unique():
explore_df_subj = explore_bouts_df[explore_bouts_df["subject"] == subject]
explore_df_subj["hour"] = explore_df_subj["start"].dt.floor("h")
min_hour, max_hour = (
explore_df_subj["hour"].min(),
explore_df_subj["hour"].max(),
)
complete_hours = pd.DataFrame(
{"hour": pd.date_range(start=min_hour, end=max_hour, freq="h")}
)
hour_counts = (
explore_df_subj.groupby("hour").size().reset_index(name="n_bouts")
)
explore_df_subj_hour = pd.merge(
complete_hours, hour_counts, on="hour", how="left"
).fillna(0)
explore_df_subj_hour["n_bouts"] = explore_df_subj_hour["n_bouts"].astype(
int
)
explore_df_subj_hour["exp"] = exp["name"]
explore_df_subj_hour["social"] = period == "social"
explore_df_subj_hour["subject"] = subject
hour = explore_df_subj_hour["hour"].dt.hour
explore_df_subj_hour["light"] = ~((hour > light_off) & (hour < light_on))
explore_hour_df = pd.concat(
[explore_hour_df, explore_df_subj_hour], ignore_index=True
)
explore_dur_subj = explore_df_subj[["start", "duration"]].copy()
explore_dur_subj["exp"] = exp["name"]
explore_dur_subj["social"] = period == "social"
explore_dur_subj["subject"] = subject
hour = explore_dur_subj["start"].dt.hour
explore_dur_subj["light"] = ~((hour > light_off) & (hour < light_on))
explore_dur_df = pd.concat(
[explore_dur_df, explore_dur_subj], ignore_index=True
)
explore_dur_df["duration"] = explore_dur_df["duration"].dt.total_seconds() / 60
explore_dur_df = explore_dur_df[explore_dur_df["duration"] < 120]
"""Plot hist of bouts per hour"""
fig, ax = plt.subplots(figsize=(14, 8))
# Plot histograms for each combination
for i, (social_val, light_val) in enumerate(combos):
subset = explore_hour_df[
(explore_hour_df["social"] == social_val)
& (explore_hour_df["light"] == light_val)
]
# Plot normalized histogram
hist = sns.histplot(
data=subset,
x="n_bouts",
stat="probability",
alpha=0.5,
color=colors[light_val],
label=labels[i],
common_norm=False, # Ensure each histogram is normalized separately
        ax=ax,
binwidth=1,
)
# Set hatch pattern for bars
if hatches[social_val]:
for bar in hist.patches:
bar.set_hatch(hatches[social_val])
ax.set_title("Normalized Exploration Bout Distributions by Social and Light Conditions")
ax.set_xlabel("Number of bouts / hour")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
ax.set_xticks(np.arange(0, 15, 2))
ax.set_xlim(0, 15)
(0.0, 15.0)

"""Plot bars of bouts per hour"""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in [True, False]:
for light_val in [True, False]:
subset = explore_hour_df[
(explore_hour_df["social"] == social_val)
& (explore_hour_df["light"] == light_val)
]
mean_n_bouts = subset["n_bouts"].mean()
sem_n_bouts = subset["n_bouts"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_n_bouts": mean_n_bouts,
"sem": sem_n_bouts,
"condition": f"{'Social' if social_val else 'Solo'}-{'Light' if light_val else 'Dark'}",
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_n_bouts"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
# Apply hatching for social conditions
if hatches[social_val]:
bar[0].set_hatch(hatches[social_val])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_n_bouts"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Number of Exploration Bouts by Social and Light Conditions")
ax.set_ylabel("Number of bouts / hour")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper left")
ax.xaxis.grid(False)
# Perform Welch's two-sample t-tests
light_social = explore_hour_df[
(explore_hour_df["social"] == True) & (explore_hour_df["light"] == True)
]["n_bouts"]
light_solo = explore_hour_df[
(explore_hour_df["social"] == False) & (explore_hour_df["light"] == True)
]["n_bouts"]
dark_social = explore_hour_df[
(explore_hour_df["social"] == True) & (explore_hour_df["light"] == False)
]["n_bouts"]
dark_solo = explore_hour_df[
(explore_hour_df["social"] == False) & (explore_hour_df["light"] == False)
]["n_bouts"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.02,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Text(0.02, 0.68, 'Two-sample t-tests:\nLight conditions: p = 1.31e-15\nDark conditions: p = 3.01e-10')

"""Plot hist of durations of bouts."""
fig, ax = plt.subplots(figsize=(14, 8))
# Plot histograms for each combination
for i, (social_val, light_val) in enumerate(combos):
subset = explore_dur_df[
(explore_dur_df["social"] == social_val)
& (explore_dur_df["light"] == light_val)
]
# Plot normalized histogram
hist = sns.histplot(
data=subset,
x="duration",
stat="probability",
alpha=0.5,
color=colors[light_val],
label=labels[i],
# kde=True, # Add kernel density estimate
common_norm=False, # Ensure each histogram is normalized separately
        ax=ax,
binwidth=2,
)
# Set hatch pattern for bars
if hatches[social_val]:
for bar in hist.patches:
bar.set_hatch(hatches[social_val])
ax.set_title(
"Normalized Exploration Bout Duration Distributions by Social and Light Conditions"
)
ax.set_xlabel("Duration (mins)")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
ax.set_ylim(0, 0.3)
(0.0, 0.3)

"""Plot bars of durations of bouts."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in [True, False]:
for light_val in [True, False]:
subset = explore_dur_df[
(explore_dur_df["social"] == social_val)
& (explore_dur_df["light"] == light_val)
]
mean_duration = subset["duration"].mean()
sem_duration = subset["duration"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_duration": mean_duration,
"sem": sem_duration,
"condition": f"{'Social' if social_val else 'Solo'}-{'Light' if light_val else 'Dark'}",
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_duration"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
# Apply hatching for social conditions
if hatches[social_val]:
bar[0].set_hatch(hatches[social_val])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_duration"] + row["sem"] + 0.2,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Exploration Bout Duration by Social and Light Conditions")
ax.set_ylabel("Duration (minutes)")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper left")
ax.xaxis.grid(False)
# Perform Welch's two-sample t-tests
light_social = explore_dur_df[
(explore_dur_df["social"] == True) & (explore_dur_df["light"] == True)
]["duration"]
light_solo = explore_dur_df[
(explore_dur_df["social"] == False) & (explore_dur_df["light"] == True)
]["duration"]
dark_social = explore_dur_df[
(explore_dur_df["social"] == True) & (explore_dur_df["light"] == False)
]["duration"]
dark_solo = explore_dur_df[
(explore_dur_df["social"] == False) & (explore_dur_df["light"] == False)
]["duration"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.02,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Text(0.02, 0.68, 'Two-sample t-tests:\nLight conditions: p = 1.79e-03\nDark conditions: p = 3.03e-08')

"""Plot hist of times of bouts over all hours."""
fig, ax = plt.subplots(figsize=(14, 8))
for i, social_val in enumerate([True, False]):
subset = explore_dur_df[(explore_dur_df["social"] == social_val)]
# Create the histogram
hist = sns.histplot(
data=subset,
x=subset["start"].dt.hour,
stat="probability", # Normalize to show probability
alpha=0.5,
color="teal",
label="Social" if social_val else "Solo",
common_norm=False, # Each condition normalized separately
ax=ax,
bins=24, # 24 hours
discrete=True, # Since hours are discrete values
)
# Apply hatching pattern for social conditions
if hatches[social_val]:
# Apply the hatch pattern to each bar
for patch in hist.patches:
patch.set_hatch(hatches[social_val])
# Set x-tick labels for every hour
ax.set_xticks(range(0, 24))
ax.set_xticklabels([f"{h:02d}:00" for h in range(0, 24)], rotation=45)
# Customize axis labels and title
ax.set_title("Distribution of Exploration Bouts Throughout the Day")
ax.set_xlabel("Hour of Day")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
<matplotlib.legend.Legend at 0x755ef01cce90>

Foraging#
# Final dfs:
# 1. forage_hour_df: hour, n_pellets, dist_forage, n_bouts, exp, social-bool, subject, light-bool
# 2. forage_dur_df: start, duration(mins), exp, social-bool, subject, light-bool
forage_hour_df = pd.DataFrame(
columns=[
"hour",
"n_bouts",
"n_pellets",
"dist_forage",
"exp",
"social",
"subject",
"light",
]
)
forage_dur_df = pd.DataFrame(
columns=["start", "duration", "exp", "social", "subject", "light"]
)
# For each period
# Load foraging data
# Split into individual dfs
# Calculate hour-by-hour metrics and put into final df
exp_pbar = tqdm(experiments, desc="Experiments", position=0, leave=True)
for exp in exp_pbar:
period_pbar = tqdm(periods, desc="Periods", position=1, leave=False)
for period in period_pbar:
forage_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="foraging",
data_dir=data_dir,
set_time_index=True,
)
for subject in forage_df["subject"].unique():
forage_df_subj = forage_df[forage_df["subject"] == subject]
forage_df_subj["hour"] = forage_df_subj["start"].dt.floor("h")
hour_counts = pd.merge(
forage_df_subj.groupby("hour").size().reset_index(name="n_bouts"),
forage_df_subj.groupby("hour").agg(
n_pellets=("n_pellets", "sum"),
cum_wheel_dist=("cum_wheel_dist", "sum"),
),
on="hour",
how="left",
)
min_hour, max_hour = (
forage_df_subj["hour"].min(),
forage_df_subj["hour"].max(),
)
complete_hours = pd.DataFrame(
{"hour": pd.date_range(start=min_hour, end=max_hour, freq="h")}
)
forage_df_subj_hour = pd.merge(
complete_hours, hour_counts, on="hour", how="left"
).fillna(0)
forage_df_subj_hour["n_bouts"] = forage_df_subj_hour["n_bouts"].astype(int)
# Rename 'cum_wheel_dist' col
forage_df_subj_hour.rename(
columns={"cum_wheel_dist": "dist_forage"}, inplace=True
)
forage_df_subj_hour["exp"] = exp["name"]
forage_df_subj_hour["social"] = period == "social"
forage_df_subj_hour["subject"] = subject
hour = forage_df_subj_hour["hour"].dt.hour
forage_df_subj_hour["light"] = ~((hour > light_off) & (hour < light_on))
forage_hour_df = pd.concat(
[forage_hour_df, forage_df_subj_hour], ignore_index=True
)
forage_dur_subj = forage_df_subj[["start"]].copy()
forage_dur_subj["duration"] = (
forage_df_subj["end"] - forage_df_subj["start"]
).dt.total_seconds() / 60
forage_dur_subj["exp"] = exp["name"]
forage_dur_subj["social"] = period == "social"
forage_dur_subj["subject"] = subject
hour = forage_df_subj["start"].dt.hour
forage_dur_subj["light"] = ~((hour > light_off) & (hour < light_on))
forage_dur_df = pd.concat(
[forage_dur_df, forage_dur_subj], ignore_index=True
)
"""Foraging bouts per hour histogram."""
fig, ax = plt.subplots(figsize=(14, 8))
# Plot histograms for each combination
for i, (social_val, light_val) in enumerate(combos):
subset = forage_hour_df[
(forage_hour_df["social"] == social_val)
& (forage_hour_df["light"] == light_val)
& (forage_hour_df["n_pellets"] > 0)
]
# Plot normalized histogram
hist = sns.histplot(
data=subset,
x="n_bouts",
stat="probability",
alpha=0.5,
color=colors[light_val],
label=labels[i],
# kde=True, # Add kernel density estimate
common_norm=False, # Ensure each histogram is normalized separately
        ax=ax,
binwidth=1,
)
# Set hatch pattern for bars
if hatches[social_val]:
for bar in hist.patches:
bar.set_hatch(hatches[social_val])
ax.set_title("Normalized Foraging Bout Distributions by Social and Light Conditions")
ax.set_xlabel("Foraging bouts / hour")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
ax.set_xlim(1, 15)
(1.0, 15.0)

"""Foraging bouts per hour bars."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in [True, False]:
for light_val in [True, False]:
subset = forage_hour_df[
(forage_hour_df["social"] == social_val)
& (forage_hour_df["light"] == light_val)
]
mean_n_bouts = subset["n_bouts"].mean()
sem_n_bouts = subset["n_bouts"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_n_bouts": mean_n_bouts,
"sem": sem_n_bouts,
"condition": f"{'Social' if social_val else 'Solo'}-{'Light' if light_val else 'Dark'}",
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_n_bouts"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_n_bouts']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val]:
bar[0].set_hatch(hatches[social_val])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_n_bouts"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Number of Foraging Bouts per Hour by Social and Light Conditions")
ax.set_ylabel("Number of bouts / hour")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper left")
ax.xaxis.grid(False)
# Welch's two-sample t-tests
light_social = forage_hour_df[
(forage_hour_df["social"] == True) & (forage_hour_df["light"] == True)
]["n_bouts"]
light_solo = forage_hour_df[
(forage_hour_df["social"] == False) & (forage_hour_df["light"] == True)
]["n_bouts"]
dark_social = forage_hour_df[
(forage_hour_df["social"] == True) & (forage_hour_df["light"] == False)
]["n_bouts"]
dark_solo = forage_hour_df[
(forage_hour_df["social"] == False) & (forage_hour_df["light"] == False)
]["n_bouts"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.02,
0.68,
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=0.29, sem=0.02, n=1801
Plotting Social-Dark: mean=4.07, sem=0.08, n=1804
Plotting Solo-Light: mean=0.10, sem=0.02, n=491
Plotting Solo-Dark: mean=3.22, sem=0.17, n=467
Text(0.02, 0.68, 'Two-sample t-tests:\nLight conditions: p = 1.10e-12\nDark conditions: p = 4.51e-06')

"""Foraging bouts duration histogram."""
fig, ax = plt.subplots(figsize=(14, 8))
# Plot histograms for each combination
for i, (social_val, light_val) in enumerate(combos):
subset = forage_dur_df[
(forage_dur_df["social"] == social_val) & (forage_dur_df["light"] == light_val)
]
# Plot normalized histogram
hist = sns.histplot(
data=subset,
x="duration",
stat="probability",
alpha=0.5,
color=colors[light_val],
label=labels[i],
# kde=True, # Add kernel density estimate
common_norm=False, # Ensure each histogram is normalized separately
        ax=ax,
binwidth=1,
)
# Set hatch pattern for bars
if hatches[social_val]:
for bar in hist.patches:
bar.set_hatch(hatches[social_val])
ax.set_title(
"Normalized Foraging Bout Duration Distributions by Social and Light Conditions"
)
ax.set_xlabel("Duration (mins)")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
ax.set_xlim(0, 20)
# ax.set_ylim(0, 0.3)
(0.0, 20.0)

"""Foraging bouts duration bars."""
max_forage_thresh = 30 # in minutes
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in [True, False]:
for light_val in [True, False]:
subset = forage_dur_df[
(forage_dur_df["social"] == social_val)
& (forage_dur_df["light"] == light_val)
& (forage_dur_df["duration"] < max_forage_thresh)
]
mean_duration = subset["duration"].mean()
sem_duration = subset["duration"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_duration": mean_duration,
"sem": sem_duration,
"condition": f"{'Social' if social_val else 'Solo'}-{'Light' if light_val else 'Dark'}",
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_duration"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_duration']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val]:
bar[0].set_hatch(hatches[social_val])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_duration"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Duration of Foraging Bouts by Social and Light Conditions")
ax.set_ylabel("Mean Duration (minutes)")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper right")
ax.xaxis.grid(False)
# Welch's two-sample t-tests
light_social = forage_dur_df[
(forage_dur_df["social"] == True)
& (forage_dur_df["light"] == True)
& (forage_dur_df["duration"] < max_forage_thresh)
]["duration"]
light_solo = forage_dur_df[
(forage_dur_df["social"] == False)
& (forage_dur_df["light"] == True)
& (forage_dur_df["duration"] < max_forage_thresh)
]["duration"]
dark_social = forage_dur_df[
(forage_dur_df["social"] == True)
& (forage_dur_df["light"] == False)
& (forage_dur_df["duration"] < max_forage_thresh)
]["duration"]
dark_solo = forage_dur_df[
(forage_dur_df["social"] == False)
& (forage_dur_df["light"] == False)
& (forage_dur_df["duration"] < max_forage_thresh)
]["duration"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.80,
0.68,
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=3.94, sem=0.12, n=518
Plotting Social-Dark: mean=3.01, sem=0.02, n=7350
Plotting Solo-Light: mean=7.40, sem=0.60, n=48
Plotting Solo-Dark: mean=3.31, sem=0.06, n=1504
Text(0.8, 0.68, 'Two-sample t-tests:\nLight conditions: p = 7.01e-07\nDark conditions: p = 1.70e-06')

"""Foraging bouts over all hours histogram."""
fig, ax = plt.subplots(figsize=(14, 8))
for i, social_val in enumerate([True, False]):
subset = forage_dur_df[(forage_dur_df["social"] == social_val)]
# Create the histogram
hist = sns.histplot(
data=subset,
x=subset["start"].dt.hour,
stat="probability", # Normalize to show probability
alpha=0.5,
color="teal",
label="Social" if social_val else "Solo",
common_norm=False, # Each condition normalized separately
ax=ax,
bins=24, # 24 hours
discrete=True, # Since hours are discrete values
)
# Apply hatching pattern for social conditions
if hatches[social_val]:
# Apply the hatch pattern to each bar
for patch in hist.patches:
patch.set_hatch(hatches[social_val])
# Set x-tick labels for every hour
ax.set_xticks(range(0, 24))
ax.set_xticklabels([f"{h:02d}:00" for h in range(0, 24)], rotation=45)
# Customize axis labels and title
ax.set_title("Distribution of Foraging Bouts Throughout the Day")
ax.set_xlabel("Hour of Day")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
<matplotlib.legend.Legend at 0x755ee7c90d50>

"""Pellet rate per hour histogram."""
fig, ax = plt.subplots(figsize=(14, 8))
# Plot histograms for each combination
for i, (social_val, light_val) in enumerate(combos):
subset = forage_hour_df[
(forage_hour_df["social"] == social_val)
& (forage_hour_df["light"] == light_val)
& (forage_hour_df["n_pellets"] > 0)
]
# Plot normalized histogram
hist = sns.histplot(
data=subset,
x="n_pellets",
stat="probability",
alpha=0.5,
color=colors[light_val],
label=labels[i],
# kde=True, # Add kernel density estimate
common_norm=False, # Ensure each histogram is normalized separately
        ax=ax,
binwidth=1,
)
# Set hatch pattern for bars
if hatches[social_val]:
for bar in hist.patches:
bar.set_hatch(hatches[social_val])
ax.set_title("Normalized Pellet Rate Distributions by Social and Light Conditions")
ax.set_xlabel("Number of pellets / hour")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
ax.set_xlim(3, 35)
(3.0, 35.0)

"""Pellet rate per hour bars."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in [True, False]:
for light_val in [True, False]:
subset = forage_hour_df[
(forage_hour_df["social"] == social_val)
& (forage_hour_df["light"] == light_val)
]
mean_n_pellets = subset["n_pellets"].mean()
sem_n_pellets = subset["n_pellets"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_n_pellets": mean_n_pellets,
"sem": sem_n_pellets,
"condition": f"{'Social' if social_val else 'Solo'}-{'Light' if light_val else 'Dark'}",
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_n_pellets"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_n_pellets']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val]:
bar[0].set_hatch(hatches[social_val])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_n_pellets"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Number of Pellets per hour by Social and Light Conditions")
ax.set_ylabel("Number of pellets / hour")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper left")
ax.xaxis.grid(False)
light_social = forage_hour_df[
(forage_hour_df["social"] == True) & (forage_hour_df["light"] == True)
]["n_pellets"]
light_solo = forage_hour_df[
(forage_hour_df["social"] == False) & (forage_hour_df["light"] == True)
]["n_pellets"]
dark_social = forage_hour_df[
(forage_hour_df["social"] == True) & (forage_hour_df["light"] == False)
]["n_pellets"]
dark_solo = forage_hour_df[
(forage_hour_df["social"] == False) & (forage_hour_df["light"] == False)
]["n_pellets"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.02,
0.68, # Position below the legend (since legend is upper left)
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=1.13, sem=0.08, n=1801
Plotting Social-Dark: mean=11.53, sem=0.22, n=1804
Plotting Solo-Light: mean=0.83, sem=0.15, n=491
Plotting Solo-Dark: mean=12.33, sem=0.54, n=467
Text(0.02, 0.68, 'Two-sample t-tests:\nLight conditions: p = 8.50e-02\nDark conditions: p = 1.70e-01')

"""Distance foraged rate per hour histogram."""
fig, ax = plt.subplots(figsize=(14, 8))
# Plot histograms for each combination
for i, (social_val, light_val) in enumerate(combos):
subset = forage_hour_df[
(forage_hour_df["social"] == social_val)
& (forage_hour_df["light"] == light_val)
& (forage_hour_df["n_pellets"] > 0)
]
# Plot normalized histogram
hist = sns.histplot(
data=subset,
x="dist_forage",
stat="probability",
alpha=0.5,
color=colors[light_val],
label=labels[i],
# kde=True, # Add kernel density estimate
common_norm=False, # Ensure each histogram is normalized separately
        ax=ax,
binwidth=500,
)
# Set hatch pattern for bars
if hatches[social_val]:
for bar in hist.patches:
bar.set_hatch(hatches[social_val])
ax.set_title("Normalized Distance Foraged Distributions by Social and Light Conditions")
ax.set_xlabel("Distance foraged / hour")
ax.set_ylabel("Probability")
ax.legend(title="Conditions")
ax.set_xlim(0, 15000)
(0.0, 15000.0)

"""Distance foraged rate per hour bars."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in [True, False]:
for light_val in [True, False]:
subset = forage_hour_df[
(forage_hour_df["social"] == social_val)
& (forage_hour_df["light"] == light_val)
]
mean_dist_forage = subset["dist_forage"].mean()
sem_dist_forage = subset["dist_forage"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_dist_forage": mean_dist_forage,
"sem": sem_dist_forage,
"condition": f"{'Social' if social_val else 'Solo'}-{'Light' if light_val else 'Dark'}",
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_dist_forage"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_dist_forage']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val]:
bar[0].set_hatch(hatches[social_val])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_dist_forage"] + row["sem"] + 10,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Distance Foraged per hour by Social and Light Conditions")
ax.set_ylabel("Distance foraged / hour (cm)")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper left")
ax.xaxis.grid(False)
light_social = forage_hour_df[
(forage_hour_df["social"] == True) & (forage_hour_df["light"] == True)
]["dist_forage"]
light_solo = forage_hour_df[
(forage_hour_df["social"] == False) & (forage_hour_df["light"] == True)
]["dist_forage"]
dark_social = forage_hour_df[
(forage_hour_df["social"] == True) & (forage_hour_df["light"] == False)
]["dist_forage"]
dark_solo = forage_hour_df[
(forage_hour_df["social"] == False) & (forage_hour_df["light"] == False)
]["dist_forage"]
# Welch's two-sample t-tests
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.02,
0.68, # Position below the legend (since legend is upper left)
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=371.06, sem=25.59, n=1801
Plotting Social-Dark: mean=3642.21, sem=78.83, n=1804
Plotting Solo-Light: mean=369.63, sem=66.23, n=491
Plotting Solo-Dark: mean=4552.04, sem=229.82, n=467
Text(0.02, 0.68, 'Two-sample t-tests:\nLight conditions: p = 9.84e-01\nDark conditions: p = 1.99e-04')

Sleeping#
For sleep, we examine three metrics: the number of bouts per hour, the duration of bouts, and the total time spent sleeping per hour.
sleep_dur_df = pd.DataFrame(
columns=["subject", "start", "end", "duration", "period", "light"]
)
sleep_hour_df = pd.DataFrame(
columns=["subject", "hour", "n_bouts", "duration", "period", "light"]
)
exp_pbar = tqdm(experiments, desc="Experiments", position=0, leave=True)
for exp in exp_pbar:
period_pbar = tqdm(periods, desc="Periods", position=1, leave=False)
for period in period_pbar:
sleep_bouts_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="sleep",
data_dir=data_dir,
set_time_index=True,
)
# Get sleep bout durations
hour = sleep_bouts_df["start"].dt.hour
sleep_bouts_df["light"] = ~((hour > light_off) & (hour < light_on))
sleep_dur_df = pd.concat([sleep_dur_df, sleep_bouts_df], ignore_index=True)
# Get n sleep bouts and total duration per hour
for subject in sleep_bouts_df["subject"].unique():
sleep_df_subj = sleep_bouts_df[sleep_bouts_df["subject"] == subject]
sleep_df_subj["hour"] = sleep_df_subj["start"].dt.floor("h")
hour_stats = (
sleep_df_subj.groupby("hour")
.agg({"duration": ["count", "sum"]})
.reset_index()
)
hour_stats.columns = ["hour", "n_bouts", "duration"]
min_hour, max_hour = (
sleep_df_subj["hour"].min(),
sleep_df_subj["hour"].max(),
)
complete_hours = pd.DataFrame(
{"hour": pd.date_range(start=min_hour, end=max_hour, freq="h")}
)
sleep_df_subj_hour = pd.merge(
complete_hours, hour_stats, on="hour", how="left"
).fillna(0)
sleep_df_subj_hour["n_bouts"] = sleep_df_subj_hour["n_bouts"].astype(int)
sleep_df_subj_hour["period"] = period
sleep_df_subj_hour["subject"] = subject
hour = sleep_df_subj_hour["hour"].dt.hour
sleep_df_subj_hour["light"] = ~((hour > light_off) & (hour < light_on))
sleep_hour_df = pd.concat(
[sleep_hour_df, sleep_df_subj_hour], ignore_index=True
)
sleep_dur_df["duration"] = (
pd.to_timedelta(sleep_dur_df["duration"]).dt.total_seconds() / 60
)
sleep_hour_df["duration"] = (
pd.to_timedelta(sleep_hour_df["duration"]).dt.total_seconds() / 60
)
"""Plot bars of bouts per hour"""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in ["social", "postsocial"]:
for light_val in [True, False]:
subset = sleep_hour_df[
(sleep_hour_df["period"] == social_val)
& (sleep_hour_df["light"] == light_val)
]
mean_n_bouts = subset["n_bouts"].mean()
sem_n_bouts = subset["n_bouts"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_n_bouts": mean_n_bouts,
"sem": sem_n_bouts,
"condition": (
f"{'Social' if social_val == 'social' else 'Solo'}-"
f"{'Light' if light_val else 'Dark'}"
),
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_n_bouts"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_n_bouts']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val == "social"]:
bar[0].set_hatch(hatches[social_val == "social"])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_n_bouts"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Number of Sleeping Bouts per hour by Social and Light Conditions")
ax.set_ylabel("Number of bouts / hour")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper center")
ax.xaxis.grid(False)
# Perform stats tests
light_social = sleep_hour_df[
(sleep_hour_df["period"] == "social") & (sleep_hour_df["light"] == True)
]["n_bouts"]
light_solo = sleep_hour_df[
    (sleep_hour_df["period"] == "postsocial") & (sleep_hour_df["light"] == True)
]["n_bouts"]
dark_social = sleep_hour_df[
(sleep_hour_df["period"] == "social") & (sleep_hour_df["light"] == False)
]["n_bouts"]
dark_solo = sleep_hour_df[
(sleep_hour_df["period"] == "postsocial") & (sleep_hour_df["light"] == False)
]["n_bouts"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}"
f"\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.40,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=3.44, sem=0.05, n=1834
Plotting Social-Dark: mean=2.06, sem=0.05, n=1815
Plotting Solo-Light: mean=1.04, sem=0.05, n=528
Plotting Solo-Dark: mean=2.26, sem=0.07, n=515
Text(0.4, 0.68, 'Two-sample t-tests:\nLight conditions: p = 7.02e-86\nDark conditions: p = 1.53e-02')

"""Plot bars of durations of bouts."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in ["social", "postsocial"]:
for light_val in [True, False]:
subset = sleep_dur_df[
(sleep_dur_df["period"] == social_val)
& (sleep_dur_df["light"] == light_val)
]
mean_duration = subset["duration"].mean()
sem_duration = subset["duration"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_duration": mean_duration,
"sem": sem_duration,
"condition": (
f"{'Social' if social_val == 'social' else 'Solo'}-"
f"{'Light' if light_val else 'Dark'}"
),
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_duration"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_duration']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val == "social"]:
bar[0].set_hatch(hatches[social_val == "social"])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_duration"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Sleeping Bout Duration by Social and Light Conditions")
ax.set_ylabel("Duration (minutes)")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper center")
ax.xaxis.grid(False)
# Perform stats tests
light_social = sleep_dur_df[
(sleep_dur_df["period"] == "social") & (sleep_dur_df["light"] == True)
]["duration"]
light_solo = sleep_dur_df[
(sleep_dur_df["period"] == "postsocial") & (sleep_dur_df["light"] == True)
]["duration"]
dark_social = sleep_dur_df[
(sleep_dur_df["period"] == "social") & (sleep_dur_df["light"] == False)
]["duration"]
dark_solo = sleep_dur_df[
(sleep_dur_df["period"] == "postsocial") & (sleep_dur_df["light"] == False)
]["duration"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}"
f"\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.40,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=6.33, sem=0.09, n=6308
Plotting Social-Dark: mean=4.20, sem=0.09, n=3743
Plotting Solo-Light: mean=48.14, sem=2.33, n=547
Plotting Solo-Dark: mean=11.79, sem=0.73, n=1164
Text(0.4, 0.68, 'Two-sample t-tests:\nLight conditions: p = 2.65e-65\nDark conditions: p = 5.44e-24')

"""Total time spent sleeping per hour."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in ["social", "postsocial"]:
for light_val in [True, False]:
subset = sleep_hour_df[
(sleep_hour_df["period"] == social_val)
& (sleep_hour_df["light"] == light_val)
]
mean_duration = subset["duration"].mean()
sem_duration = subset["duration"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_duration": mean_duration,
"sem": sem_duration,
"condition": (
f"{'Social' if social_val == 'social' else 'Solo'}-"
f"{'Light' if light_val else 'Dark'}"
),
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_duration"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_duration']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val == "social"]:
bar[0].set_hatch(hatches[social_val == "social"])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_duration"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Sleeping Time per hour by Social and Light Conditions")
ax.set_ylabel("Duration (minutes)")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper center")
ax.xaxis.grid(False)
# Perform stats tests
light_social = sleep_hour_df[
(sleep_hour_df["period"] == "social") & (sleep_hour_df["light"] == True)
]["duration"]
light_solo = sleep_hour_df[
(sleep_hour_df["period"] == "postsocial") & (sleep_hour_df["light"] == True)
]["duration"]
dark_social = sleep_hour_df[
(sleep_hour_df["period"] == "social") & (sleep_hour_df["light"] == False)
]["duration"]
dark_solo = sleep_hour_df[
(sleep_hour_df["period"] == "postsocial") & (sleep_hour_df["light"] == False)
]["duration"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}"
f"\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.40,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=21.78, sem=0.38, n=1834
Plotting Social-Dark: mean=8.66, sem=0.28, n=1815
Plotting Solo-Light: mean=49.87, sem=2.53, n=528
Plotting Solo-Dark: mean=26.64, sem=1.59, n=515
Text(0.4, 0.68, 'Two-sample t-tests:\nLight conditions: p = 7.87e-151\nDark conditions: p = 3.93e-26')

Drinking#
n_bouts / hour
duration of bouts
total time spent drinking / hour
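The per-hour metrics below follow the same aggregation pattern as the sleeping analysis: floor each bout's start time to the hour, count bouts and sum durations per hour, then reindex onto a complete hourly grid so that hours with no bouts count as zero. A minimal sketch on made-up bout data (timestamps and durations are hypothetical):
bouts = pd.DataFrame(
    {
        "start": pd.to_datetime(
            ["2024-01-01 00:05", "2024-01-01 00:40", "2024-01-01 02:10"]
        ),
        "duration": [1.5, 0.5, 2.0],  # minutes (made-up values)
    }
)
bouts["hour"] = bouts["start"].dt.floor("h")
hour_stats = bouts.groupby("hour").agg(
    n_bouts=("duration", "count"), duration=("duration", "sum")
)
full_hours = pd.date_range(bouts["hour"].min(), bouts["hour"].max(), freq="h")
hour_stats = hour_stats.reindex(full_hours, fill_value=0)  # empty 01:00 hour -> 0 bouts
print(hour_stats)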
drink_dur_df = pd.DataFrame(
columns=["subject", "start", "end", "duration", "period", "light"]
)
drink_hour_df = pd.DataFrame(
columns=["subject", "hour", "n_bouts", "duration", "period", "light"]
)
exp_pbar = tqdm(experiments, desc="Experiments", position=0, leave=True)
for exp in exp_pbar:
if exp["name"] == "social0.3-aeon4":
continue # Skip this experiment as the data is not available
period_pbar = tqdm(periods, desc="Periods", position=1, leave=False)
for period in period_pbar:
drink_bouts_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="drink",
data_dir=data_dir,
set_time_index=True,
)
# Get drink bout durations
hour = drink_bouts_df["start"].dt.hour
drink_bouts_df["light"] = ~((hour > light_off) & (hour < light_on))
drink_dur_df = pd.concat([drink_dur_df, drink_bouts_df], ignore_index=True)
# Get n drink bouts and total duration per hour
for subject in drink_bouts_df["subject"].unique():
drink_df_subj = drink_bouts_df[drink_bouts_df["subject"] == subject].copy()
drink_df_subj["hour"] = drink_df_subj["start"].dt.floor("h")
hour_stats = (
drink_df_subj.groupby("hour")
.agg({"duration": ["count", "sum"]})
.reset_index()
)
hour_stats.columns = ["hour", "n_bouts", "duration"]
min_hour, max_hour = (
drink_df_subj["hour"].min(),
drink_df_subj["hour"].max(),
)
complete_hours = pd.DataFrame(
{"hour": pd.date_range(start=min_hour, end=max_hour, freq="h")}
)
drink_df_subj_hour = pd.merge(
complete_hours, hour_stats, on="hour", how="left"
).fillna(0)
drink_df_subj_hour["n_bouts"] = drink_df_subj_hour["n_bouts"].astype(int)
drink_df_subj_hour["period"] = period
drink_df_subj_hour["subject"] = subject
hour = drink_df_subj_hour["hour"].dt.hour
drink_df_subj_hour["light"] = ~((hour > light_off) & (hour < light_on))
drink_hour_df = pd.concat(
[drink_hour_df, drink_df_subj_hour], ignore_index=True
)
drink_dur_df["duration"] = (
pd.to_timedelta(drink_dur_df["duration"]).dt.total_seconds() / 60
)
drink_hour_df["duration"] = (
pd.to_timedelta(drink_hour_df["duration"]).dt.total_seconds() / 60
)
"""Number of drinking bouts per hour bars."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in ["social", "postsocial"]:
for light_val in [True, False]:
subset = drink_hour_df[
(drink_hour_df["period"] == social_val)
& (drink_hour_df["light"] == light_val)
]
mean_n_bouts = subset["n_bouts"].mean()
sem_n_bouts = subset["n_bouts"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_n_bouts": mean_n_bouts,
"sem": sem_n_bouts,
"condition": (
f"{'Social' if social_val == 'social' else 'Solo'}-"
f"{'Light' if light_val else 'Dark'}"
),
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_n_bouts"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
print(
f"Plotting {row['condition']}: mean={row['mean_n_bouts']:.2f}, sem={row['sem']:.2f}, n={row['n']}"
)
# Apply hatching for social conditions
if hatches[social_val == "social"]:
bar[0].set_hatch(hatches[social_val == "social"])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_n_bouts"] + row["sem"] + 0.1,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Number of Drinking Bouts per hour by Social and Light Conditions")
ax.set_ylabel("Number of bouts / hour")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
# ax.set_ylim([0, 2.01])
ax.legend(title="Conditions", loc="upper center")
ax.xaxis.grid(False)
# Perform stats tests
light_social = drink_hour_df[
(drink_hour_df["period"] == "social") & (drink_hour_df["light"] == True)
]["n_bouts"]
light_solo = drink_hour_df[
(drink_hour_df["period"] == "postsocial") & (drink_hour_df["light"] == True)
]["n_bouts"]
dark_social = drink_hour_df[
(drink_hour_df["period"] == "social") & (drink_hour_df["light"] == False)
]["n_bouts"]
dark_solo = drink_hour_df[
(drink_hour_df["period"] == "postsocial") & (drink_hour_df["light"] == False)
]["n_bouts"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}"
f"\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.40,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Plotting Social-Light: mean=6.59, sem=0.33, n=1498
Plotting Social-Dark: mean=12.53, sem=0.32, n=1498
Plotting Solo-Light: mean=2.09, sem=0.18, n=432
Plotting Solo-Dark: mean=10.43, sem=0.41, n=430
Text(0.4, 0.68, 'Two-sample t-tests:\nLight conditions: p = 2.30e-36\nDark conditions: p = 6.68e-05')

"""Plot bars of durations of bouts."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in ["social", "postsocial"]:
for light_val in [True, False]:
subset = drink_dur_df[
(drink_dur_df["period"] == social_val)
& (drink_dur_df["light"] == light_val)
]
mean_duration = subset["duration"].mean()
sem_duration = subset["duration"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_duration": mean_duration,
"sem": sem_duration,
"condition": (
f"{'Social' if social_val == 'social' else 'Solo'}-"
f"{'Light' if light_val else 'Dark'}"
),
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_duration"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
# Apply hatching for social conditions
if hatches[social_val == "social"]:
bar[0].set_hatch(hatches[social_val == "social"])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_duration"] + row["sem"] + 0.01,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Drinking Bout Duration by Social and Light Conditions")
ax.set_ylabel("Duration (minutes)")
ax.set_xticks(x_pos)
ax.set_ylim([0, 0.351])
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper center")
ax.xaxis.grid(False)
# Perform stats tests
light_social = drink_dur_df[
(drink_dur_df["period"] == "social") & (drink_dur_df["light"] == True)
]["duration"]
light_solo = drink_dur_df[
(drink_dur_df["period"] == "postsocial") & (drink_dur_df["light"] == True)
]["duration"]
dark_social = drink_dur_df[
(drink_dur_df["period"] == "social") & (drink_dur_df["light"] == False)
]["duration"]
dark_solo = drink_dur_df[
(drink_dur_df["period"] == "postsocial") & (drink_dur_df["light"] == False)
]["duration"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}"
f"\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.40,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Text(0.4, 0.68, 'Two-sample t-tests:\nLight conditions: p = 3.68e-64\nDark conditions: p = 1.10e-04')

"""Total time spent drinking per hour."""
fig, ax = plt.subplots(figsize=(14, 8))
summary_data = []
for social_val in ["social", "postsocial"]:
for light_val in [True, False]:
subset = drink_hour_df[
(drink_hour_df["period"] == social_val)
& (drink_hour_df["light"] == light_val)
]
mean_duration = subset["duration"].mean()
sem_duration = subset["duration"].sem()
n_samples = len(subset)
summary_data.append(
{
"social": social_val,
"light": light_val,
"mean_duration": mean_duration,
"sem": sem_duration,
"condition": (
f"{'Social' if social_val == 'social' else 'Solo'}-"
f"{'Light' if light_val else 'Dark'}"
),
"n": n_samples,
}
)
summary_df = pd.DataFrame(summary_data)
# Set up positions for the bars
bar_width = 0.5
x_pos = np.array([0.25, 2.25, 0.75, 2.75]) # create two groups with a gap in the middle
# Plot bars
for i, row in enumerate(summary_data):
pos = x_pos[i]
social_val = row["social"]
light_val = row["light"]
bar = ax.bar(
pos,
row["mean_duration"],
bar_width,
yerr=row["sem"],
color=colors[light_val],
edgecolor="black",
capsize=7,
label=row["condition"],
)
# Apply hatching for social conditions
if hatches[social_val == "social"]:
bar[0].set_hatch(hatches[social_val == "social"])
# Add sample size as text above each bar
sample_size_txt = ax.text(
pos,
row["mean_duration"] + row["sem"] + 0.01,
f"n={row['n']}",
ha="center",
va="bottom",
)
sample_size_txt.set_fontsize(11)
ax.set_title("Mean Drinking Time per hour by Social and Light Conditions")
ax.set_ylabel("Duration (minutes)")
ax.set_xticks(x_pos)
ax.set_xticklabels(["Social\nLight", "Social\nDark", "Solo\nLight", "Solo\nDark"])
ax.legend(title="Conditions", loc="upper center")
ax.xaxis.grid(False)
# Perform stats tests
light_social = drink_hour_df[
(drink_hour_df["period"] == "social") & (drink_hour_df["light"] == True)
]["duration"]
light_solo = drink_hour_df[
(drink_hour_df["period"] == "postsocial") & (drink_hour_df["light"] == True)
]["duration"]
dark_social = drink_hour_df[
(drink_hour_df["period"] == "social") & (drink_hour_df["light"] == False)
]["duration"]
dark_solo = drink_hour_df[
(drink_hour_df["period"] == "postsocial") & (drink_hour_df["light"] == False)
]["duration"]
light_social = pd.to_numeric(light_social, errors="coerce").dropna()
light_solo = pd.to_numeric(light_solo, errors="coerce").dropna()
dark_social = pd.to_numeric(dark_social, errors="coerce").dropna()
dark_solo = pd.to_numeric(dark_solo, errors="coerce").dropna()
light_stat, light_p = stats.ttest_ind(
light_social, light_solo, alternative="two-sided", equal_var=False
)
dark_stat, dark_p = stats.ttest_ind(
dark_social, dark_solo, alternative="two-sided", equal_var=False
)
test_text = (
f"Two-sample t-tests:\n"
f"Light conditions: p = {light_p:.2e}"
f"\nDark conditions: p = {dark_p:.2e}"
)
props = dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
ax.text(
0.40,
0.68, # Position below the legend
test_text,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
)
Text(0.4, 0.68, 'Two-sample t-tests:\nLight conditions: p = 9.74e-13\nDark conditions: p = 6.84e-07')

Solo vs. Social Learning#
learning_df = pd.DataFrame( # per-block, per-subject
columns=[
"experiment_name",
"period",
"block_start",
"block_type", # "lll", "lmh", or "hhh"
"block_type_rate", # "l" (100, 300, 500) or "h" (200, 600, 1000)
"subject_name",
"pel_thresh", # sorted by time
"pel_patch", # "l", "m", or "h"
"running_patch_pref_low", # every X foraging dist
"running_patch_pref_high", # every X foraging dist
"final_patch_pref_low", # final patch pref
"final_patch_pref_high", # final patch pref
"dist_forage_low", # final distance foraged
"dist_forage_med", # final distance foraged
"dist_forage_high", # final distance foraged
]
)
def find_first_x_indxs(
dist_forage: np.ndarray, dist_threshold: np.ndarray
) -> np.ndarray:
"""For each value in dist_threshold, find the first index in dist_forage that exceeds this."""
idxs = np.searchsorted(dist_forage, dist_threshold)
idxs = idxs[idxs < len(dist_forage)]
return idxs
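# For intuition, with made-up values:
# find_first_x_indxs(np.array([0.0, 100.0, 250.0, 600.0]), np.array([50.0, 300.0, 900.0]))
# returns array([1, 3]): index 1 holds the first cumulative distance >= 50,
# index 3 the first >= 300, and 900 lies beyond the array so it is dropped.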
def create_patch_name_type_map(block_start, subject_name, patch_df):
# Filter patch_df for this specific block_start and subject_name
relevant_patches = patch_df[
(patch_df["block_start"] == block_start)
& (patch_df["subject_name"] == subject_name)
]
# Initialize the mapping dictionary
patch_name_type_map = {"l": [], "m": [], "h": []}
# Group by patch_type and collect patch_names
for patch_type, group in relevant_patches.groupby("patch_type"):
patch_names = group["patch_name"].unique().tolist()
patch_name_type_map[patch_type] = patch_names
return patch_name_type_map
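# The returned map groups patch names by rate type, e.g. (hypothetical names)
# {"l": ["Patch2"], "m": ["Patch1"], "h": ["Patch3"]} for an lmh block, or
# {"l": [], "m": [], "h": ["Patch1", "Patch2", "Patch3"]} for an hhh block.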
pref_every = np.arange(0, 16000, 400) # cm
frg_blk_pel_thresh = 3 # pellets
exp_pbar = tqdm(experiments, desc="Experiments", position=0, leave=True)
for exp in exp_pbar:
period_pbar = tqdm(periods, desc="Periods", position=1, leave=False)
for period in period_pbar:
cur_learning_df = pd.DataFrame(columns=learning_df.columns)
# <s> Load all relevant patch data
patchinfo_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="patchinfo",
data_dir=data_dir,
set_time_index=True,
)
patch_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="patch",
data_dir=data_dir,
set_time_index=True,
)
patchpref_df = load_data_from_parquet(
experiment_name=exp["name"],
period=period,
data_type="patchpref",
data_dir=data_dir,
set_time_index=True,
)
# </s>
# <s> Clean up `patchinfo_df` and `patch_df`
patch_df = patch_df[patch_df["patch_name"] != "PatchDummy1"]
patchinfo_df = patchinfo_df[patchinfo_df["patch_name"] != "PatchDummy1"]
# Drop blocks where 'patch_rate' is NaN or None
nan_patch_rate_rows = patchinfo_df[patchinfo_df["patch_rate"].isna()]
unique_block_starts_to_drop = nan_patch_rate_rows["block_start"].unique()
if len(unique_block_starts_to_drop) != 0:
warn(
f"{exp['name']} {period} blocks with missing patch rate(s): "
f"{unique_block_starts_to_drop}",
stacklevel=1,
)
patchinfo_df = patchinfo_df[
~patchinfo_df["block_start"].isin(unique_block_starts_to_drop)
]
patch_df = patch_df[
~patch_df["block_start"].isin(unique_block_starts_to_drop)
]
# patch_df = patch_df[patch_df["pellet_count"] > 0]
# Get patch type per row: for each row in `patch_df`, find the equivalent row in
# `patchinfo_df` (based on 'block_start' and 'patch_name'), and get the patch_type
# from the map.
patchinfo_lookup = patchinfo_df.set_index(["block_start", "patch_name"])[
"patch_rate"
].to_dict()
patch_df["patch_type"] = patch_df.apply(
lambda row: patch_type_rate_map[
patchinfo_lookup[(row["block_start"], row["patch_name"])]
],
axis=1,
)
patch_df["patch_type_per_pellet"] = patch_df.apply(
lambda row: np.full(len(row["pellet_timestamps"]), row["patch_type"]),
axis=1,
)
# </s>
# <s> Get pel_thresh and pel_patch cols
patch_df_block_subj = patch_df.groupby(["block_start", "subject_name"]).agg(
dist_forage=("wheel_cumsum_distance_travelled", lambda x: x.sum()),
pellet_count=("pellet_count", lambda x: x.sum()),
pellet_threshold=("patch_threshold", lambda x: np.concatenate(x.values)),
pellet_timestamp=("pellet_timestamps", lambda x: np.concatenate(x.values)),
patch_type=("patch_type_per_pellet", lambda x: np.concatenate(x.values)),
)
patch_df_block_subj = patch_df_block_subj[
patch_df_block_subj["pellet_count"] >= frg_blk_pel_thresh
]
patch_df_block_subj.reset_index(inplace=True)
# for each row, get patch_threshold sorted ascending by pellet_timestamps
cur_learning_df["pel_thresh"] = patch_df_block_subj.apply(
lambda row: np.array(row["pellet_threshold"])[
np.argsort(row["pellet_timestamp"])
],
axis=1,
)
cur_learning_df["pel_patch"] = patch_df_block_subj.apply(
lambda row: np.array(row["patch_type"])[
np.argsort(row["pellet_timestamp"])
],
axis=1,
)
# </s>
# <s> Get metrics by patch type
# get low, med, high patch for all blocks
patch_df_block_subj["patch_name_type_map"] = patch_df_block_subj.apply(
lambda row: create_patch_name_type_map(
row["block_start"], row["subject_name"], patch_df
),
axis=1,
)
# get pref_idxs from `patch_df_block_subj["dist_forage"]` at each
# cum `pref_every` dist
pref_every_thresh_idxs = patch_df_block_subj["dist_forage"].apply(
lambda x: find_first_x_indxs(x, pref_every) # type: ignore
)
# get preference for these patches at `pref_every_thresh_idxs`
patchpref_df = patchpref_df[
patchpref_df["block_start"].isin(patch_df_block_subj["block_start"])
]
for block_i, block in enumerate(patch_df_block_subj.itertuples()):
# Get the patch name type mapping for this block-subject combination
patch_map = block.patch_name_type_map
if len(patch_map["l"]) == 0: # hhh block
col_pos = cur_learning_df.columns.get_loc("block_type")
cur_learning_df.iat[block_i, col_pos] = "hhh"
# running patch pref
col_pos = cur_learning_df.columns.get_loc("running_patch_pref_low")
cur_learning_df.iat[block_i, col_pos] = np.zeros(
len(pref_every_thresh_idxs[block_i])
)
col_pos = cur_learning_df.columns.get_loc("running_patch_pref_high")
cur_learning_df.iat[block_i, col_pos] = np.ones(
len(pref_every_thresh_idxs[block_i])
)
# final patch pref
col_pos = cur_learning_df.columns.get_loc("final_patch_pref_low")
cur_learning_df.iat[block_i, col_pos] = 0
col_pos = cur_learning_df.columns.get_loc("final_patch_pref_high")
cur_learning_df.iat[block_i, col_pos] = 1
# dist forage
col_pos = cur_learning_df.columns.get_loc("dist_forage_low")
cur_learning_df.iat[block_i, col_pos] = 0
col_pos = cur_learning_df.columns.get_loc("dist_forage_med")
cur_learning_df.iat[block_i, col_pos] = 0
col_pos = cur_learning_df.columns.get_loc("dist_forage_high")
cur_learning_df.iat[block_i, col_pos] = max(
0, patch_df_block_subj["dist_forage"].iloc[block_i][-1]
)
elif len(patch_map["l"]) == 3: # lll block
col_pos = cur_learning_df.columns.get_loc("block_type")
cur_learning_df.iat[block_i, col_pos] = "lll"
# running patch pref
col_pos = cur_learning_df.columns.get_loc("running_patch_pref_low")
cur_learning_df.iat[block_i, col_pos] = np.ones(
len(pref_every_thresh_idxs[block_i])
)
col_pos = cur_learning_df.columns.get_loc("running_patch_pref_high")
cur_learning_df.iat[block_i, col_pos] = np.zeros(
len(pref_every_thresh_idxs[block_i])
)
# final patch pref
col_pos = cur_learning_df.columns.get_loc("final_patch_pref_low")
cur_learning_df.iat[block_i, col_pos] = 1
col_pos = cur_learning_df.columns.get_loc("final_patch_pref_high")
cur_learning_df.iat[block_i, col_pos] = 0
# dist forage
col_pos = cur_learning_df.columns.get_loc("dist_forage_low")
cur_learning_df.iat[block_i, col_pos] = max(
0, patch_df_block_subj["dist_forage"].iloc[block_i][-1]
)
col_pos = cur_learning_df.columns.get_loc("dist_forage_med")
cur_learning_df.iat[block_i, col_pos] = 0
col_pos = cur_learning_df.columns.get_loc("dist_forage_high")
cur_learning_df.iat[block_i, col_pos] = 0
elif len(patch_map["l"]) == 1: # lmh block
col_pos = cur_learning_df.columns.get_loc("block_type")
cur_learning_df.iat[block_i, col_pos] = "lmh"
# running patch pref
l_patch = patch_map["l"][0]
col_pos = cur_learning_df.columns.get_loc("running_patch_pref_low")
l_patch_data = patchpref_df[
(patchpref_df["block_start"] == block.block_start)
& (patchpref_df["patch_name"] == l_patch)
& (patchpref_df["subject_name"] == block.subject_name)
]
cur_learning_df.iat[block_i, col_pos] = l_patch_data[
"running_preference_by_wheel"
].values[0][pref_every_thresh_idxs[block_i]]
h_patch = patch_map["h"][0] # Fixed: was using 'm' instead of 'h'
col_pos = cur_learning_df.columns.get_loc("running_patch_pref_high")
h_patch_data = patchpref_df[
(patchpref_df["block_start"] == block.block_start)
& (patchpref_df["patch_name"] == h_patch)
& (patchpref_df["subject_name"] == block.subject_name)
]
cur_learning_df.iat[block_i, col_pos] = h_patch_data[
"running_preference_by_wheel"
].values[0][pref_every_thresh_idxs[block_i]]
# final patch pref
col_pos = cur_learning_df.columns.get_loc("final_patch_pref_low")
cur_learning_df.iat[block_i, col_pos] = l_patch_data[
"final_preference_by_wheel"
].values[0]
col_pos = cur_learning_df.columns.get_loc("final_patch_pref_high")
cur_learning_df.iat[block_i, col_pos] = h_patch_data[
"final_preference_by_wheel"
].values[0]
# final dist forage
col_pos = cur_learning_df.columns.get_loc("dist_forage_low")
patch_data = patch_df[
(patch_df["block_start"] == block.block_start)
& (patch_df["patch_type"] == "l")
& (patch_df["subject_name"] == block.subject_name)
]
if not patch_data.empty:
cur_learning_df.iat[block_i, col_pos] = max(
0, patch_data["wheel_cumsum_distance_travelled"].values[0][-1]
)
else:
cur_learning_df.iat[block_i, col_pos] = 0
col_pos = cur_learning_df.columns.get_loc("dist_forage_med")
patch_data = patch_df[
(patch_df["block_start"] == block.block_start)
& (patch_df["patch_type"] == "m")
& (patch_df["subject_name"] == block.subject_name)
]
if not patch_data.empty:
cur_learning_df.iat[block_i, col_pos] = max(
0, patch_data["wheel_cumsum_distance_travelled"].values[0][-1]
)
else:
cur_learning_df.iat[block_i, col_pos] = 0
col_pos = cur_learning_df.columns.get_loc("dist_forage_high")
patch_data = patch_df[
(patch_df["block_start"] == block.block_start)
& (patch_df["patch_type"] == "h")
& (patch_df["subject_name"] == block.subject_name)
]
if not patch_data.empty:
cur_learning_df.iat[block_i, col_pos] = max(
0, patch_data["wheel_cumsum_distance_travelled"].values[0][-1]
)
else:
cur_learning_df.iat[block_i, col_pos] = 0
# </s>
# <s> Fill in rest of `cur_learning_df` cols
cur_learning_df["experiment_name"] = exp["name"]
cur_learning_df["period"] = period
cur_learning_df["block_start"] = patch_df_block_subj["block_start"]
cur_learning_df["subject_name"] = patch_df_block_subj["subject_name"]
# Get overall block type rate based on patch rates
max_patch_rate = patchinfo_df.groupby(["block_start"]).agg(
patch_rate=("patch_rate", lambda x: x.max())
)
max_patch_rate["block_type_rate"] = max_patch_rate["patch_rate"].map(
{0.002: "l", 0.01: "l", 0.001: "h", 0.005: "h"}
)
cur_learning_df["block_type_rate"] = cur_learning_df["block_start"].map(
max_patch_rate["block_type_rate"]
)
# </s>
learning_df = pd.concat([learning_df, cur_learning_df], ignore_index=True)
# Different exps have different patch threshold scales: the "l" cohort uses patch
# means of 100/300/500 cm, the "h" cohort 200/600/1000 cm, so we scale the "l"
# exps by 2 to match the "h" exps.
scaled_learning_df = learning_df.copy()
scaled_learning_df.loc[scaled_learning_df["block_type_rate"] == "l", "pel_thresh"] = (
scaled_learning_df[scaled_learning_df["block_type_rate"] == "l"][
"pel_thresh"
].apply(lambda x: np.array(x) * 2)
)
# same scaling for 'dist_forage_low', 'dist_forage_med', 'dist_forage_high'
scaled_learning_df.loc[
scaled_learning_df["block_type_rate"] == "l", "dist_forage_low"
] = scaled_learning_df[scaled_learning_df["block_type_rate"] == "l"][
"dist_forage_low"
].apply(lambda x: x * 2)
scaled_learning_df.loc[
scaled_learning_df["block_type_rate"] == "l", "dist_forage_med"
] = scaled_learning_df[scaled_learning_df["block_type_rate"] == "l"][
"dist_forage_med"
].apply(lambda x: x * 2)
scaled_learning_df.loc[
scaled_learning_df["block_type_rate"] == "l", "dist_forage_high"
] = scaled_learning_df[scaled_learning_df["block_type_rate"] == "l"][
"dist_forage_high"
].apply(lambda x: x * 2)
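Since the "h" cohort's patch means are exactly double the "l" cohort's (see the comments in the learning_df definition above), a factor of 2 puts both cohorts on a common scale. A quick check with the nominal means:
l_means = np.array([100, 300, 500])  # "l" cohort nominal patch means (cm)
h_means = np.array([200, 600, 1000])  # "h" cohort nominal patch means (cm)
print(np.array_equal(l_means * 2, h_means))  # True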
Foraging efficiency over time#
pellet-threshold as a function of block-pellet-number for “foraging” blocks
# Pellet Threshold Over Time: Social vs. Post-social (Scaled Data)
# Similar to foraging efficiency plot but using scaled_learning_df
# Helper function to pad arrays to a uniform length
def pad_array(arr, max_len):
return np.pad(arr, (0, max_len - len(arr)), mode="constant", constant_values=np.nan)
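# e.g. pad_array(np.array([450.0, 300.0]), 4) -> array([450., 300., nan, nan]),
# so blocks with different pellet counts can be stacked into one matrix.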
# Social and postsocial data processing using scaled data
social_rows_scaled = scaled_learning_df[
(scaled_learning_df["period"] == "social")
# & (scaled_learning_df["block_type"] == "lmh")
]
postsocial_rows_scaled = scaled_learning_df[
(scaled_learning_df["period"] == "postsocial")
# & (scaled_learning_df["block_type"] == "lmh")
]
# Set the cutoff lengths (same as original plot)
social_cutoff = 37
postsocial_cutoff = 37
# Smoothing parameters (same as original plot)
social_smooth_window = 7
postsocial_smooth_window = 7
# Option to normalize x-axis
normalize_x_axis = False # Set to True for unit-normalized x-axis
# Process social data from scaled_learning_df
social_thresh_arrays_scaled = [
arr[:social_cutoff] for arr in social_rows_scaled["pel_thresh"] if len(arr) > 0
]
max_len_social_scaled = max(len(arr) for arr in social_thresh_arrays_scaled)
matrix_social_scaled = np.vstack(
[pad_array(arr, max_len_social_scaled) for arr in social_thresh_arrays_scaled]
)
# Process postsocial data from scaled_learning_df
postsocial_thresh_arrays_scaled = [
arr[:postsocial_cutoff]
for arr in postsocial_rows_scaled["pel_thresh"]
if len(arr) > 0
]
max_len_postsocial_scaled = max(len(arr) for arr in postsocial_thresh_arrays_scaled)
matrix_postsocial_scaled = np.vstack(
[
pad_array(arr, max_len_postsocial_scaled)
for arr in postsocial_thresh_arrays_scaled
]
)
# Calculate means and SEM for social (scaled data)
social_run_avg_kernel = np.ones(social_smooth_window) / social_smooth_window
# Smooth each row individually, then take mean
social_smoothed_rows = np.apply_along_axis(
lambda row: np.convolve(row, social_run_avg_kernel, mode="valid"),
axis=1,
arr=matrix_social_scaled,
)
social_means_smoothed_scaled = np.nanmean(social_smoothed_rows, axis=0)
social_sem_scaled = np.nanstd(social_smoothed_rows, axis=0) / np.sqrt(
np.sum(~np.isnan(social_smoothed_rows), axis=0)
)
social_sem_smoothed_scaled = social_sem_scaled
# Calculate means and SEM for postsocial (scaled data)
postsocial_run_avg_kernel = np.ones(postsocial_smooth_window) / postsocial_smooth_window
# Smooth each row individually, then take mean
postsocial_smoothed_rows = np.apply_along_axis(
lambda row: np.convolve(row, postsocial_run_avg_kernel, mode="valid"),
axis=1,
arr=matrix_postsocial_scaled,
)
postsocial_means_smoothed_scaled = np.nanmean(postsocial_smoothed_rows, axis=0)
postsocial_sem_scaled = np.nanstd(postsocial_smoothed_rows, axis=0) / np.sqrt(
np.sum(~np.isnan(postsocial_smoothed_rows), axis=0)
)
postsocial_sem_smoothed_scaled = postsocial_sem_scaled
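# Note: np.convolve with mode="valid" shortens each smoothed row by
# (window - 1) samples, and any NaN inside a window propagates, so the
# nanmean/nanstd above effectively drop blocks once their data run out.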
# Create x-axis values
if normalize_x_axis:
social_x_scaled = np.linspace(0, 1, len(social_means_smoothed_scaled))
postsocial_x_scaled = np.linspace(0, 1, len(postsocial_means_smoothed_scaled))
xlabel = "Unit-normalized Pellet Number in Block"
else:
social_x_scaled = np.arange(len(social_means_smoothed_scaled))
postsocial_x_scaled = np.arange(len(postsocial_means_smoothed_scaled))
xlabel = "Pellet Number in Block"
# Linear regression for slopes
social_slope, social_intercept, social_r, social_p, social_se = stats.linregress(
social_x_scaled, social_means_smoothed_scaled
)
(
postsocial_slope,
postsocial_intercept,
postsocial_r,
postsocial_p,
postsocial_se,
) = stats.linregress(postsocial_x_scaled, postsocial_means_smoothed_scaled)
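# Note: stats.linregress does not skip NaNs, so the smoothed means are assumed
# to be NaN-free here; any NaN would propagate into the fitted slope.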
# Create plot with OO approach
fig, ax = plt.subplots(figsize=(14, 8))
# Plot social data (scaled)
social_line_scaled = ax.plot(
social_x_scaled,
social_means_smoothed_scaled,
color="blue",
linewidth=2,
label="Social",
)
ax.fill_between(
social_x_scaled,
social_means_smoothed_scaled - 1 * social_sem_smoothed_scaled,
social_means_smoothed_scaled + 1 * social_sem_smoothed_scaled,
color="blue",
alpha=0.2,
)
# Plot postsocial data (scaled)
postsocial_line_scaled = ax.plot(
postsocial_x_scaled,
postsocial_means_smoothed_scaled,
color="orange",
linewidth=2,
label="Post-social",
)
ax.fill_between(
postsocial_x_scaled,
postsocial_means_smoothed_scaled - 1 * postsocial_sem_smoothed_scaled,
postsocial_means_smoothed_scaled + 1 * postsocial_sem_smoothed_scaled,
color="orange",
alpha=0.2,
)
# Add text box with slope information
textstr = f"Linear Regression Slopes:\nSocial: {social_slope:.2f} ± {social_se:.2f}\nPost-social: {postsocial_slope:.2f} ± {postsocial_se:.2f}"
props = dict(boxstyle="round", facecolor="wheat", alpha=0.8)
ax.text(
0.05,
0.25,
textstr,
transform=ax.transAxes,
fontsize=12,
verticalalignment="top",
bbox=props,
)
# Add labels and styling
ax.set_title(
"Pellet Threshold Over Time: Social vs. Post-social (Scaled Data)", fontsize=20
)
ax.set_xlabel(xlabel, fontsize=18)
ax.set_ylabel("Pellet Threshold (cm)", fontsize=18)
ax.tick_params(axis="both", which="major", labelsize=15)
ax.grid(True, alpha=0.5)
ax.legend(fontsize=16)
# Improve layout
plt.tight_layout()
plt.show()
# Statistical comparison
t_stat_scaled, p_val_scaled = stats.ttest_ind(
social_means_smoothed_scaled,
postsocial_means_smoothed_scaled,
nan_policy="omit",
equal_var=False,
)
print(f"T-test (Scaled Data): t={t_stat_scaled:.3f}, p={p_val_scaled:.5f}")
print(f"Social slope: {social_slope:.3f} ± {social_se:.3f}, p={social_p:.5f}")
print(
f"Post-social slope: {postsocial_slope:.3f} ± {postsocial_se:.3f}, p={postsocial_p:.5f}"
)

T-test (Scaled Data): t=-4.425, p=0.00005
Social slope: -4.114 ± 0.106, p=0.00000
Post-social slope: -2.478 ± 0.216, p=0.00000
# Extract data for the first 5 and last 5 pellets of each block
social_first5 = social_smoothed_rows[:, :5].flatten()
social_last5 = social_smoothed_rows[:, -5:].flatten()
postsocial_first5 = postsocial_smoothed_rows[:, :5].flatten()
postsocial_last5 = postsocial_smoothed_rows[:, -5:].flatten()
# Remove NaNs
social_first5 = social_first5[~np.isnan(social_first5)]
social_last5 = social_last5[~np.isnan(social_last5)]
postsocial_first5 = postsocial_first5[~np.isnan(postsocial_first5)]
postsocial_last5 = postsocial_last5[~np.isnan(postsocial_last5)]
# Create DataFrame for plotting
plot_data = pd.DataFrame(
{
"Pellet Threshold": np.concatenate(
[social_first5, social_last5, postsocial_first5, postsocial_last5]
),
"Period": (
["Social"] * len(social_first5)
+ ["Social"] * len(social_last5)
+ ["Post-social"] * len(postsocial_first5)
+ ["Post-social"] * len(postsocial_last5)
),
"Block Position": (
["First 5 pellets"] * len(social_first5)
+ ["Last 5 pellets"] * len(social_last5)
+ ["First 5 pellets"] * len(postsocial_first5)
+ ["Last 5 pellets"] * len(postsocial_last5)
),
}
)
# Create the plot
fig, ax = plt.subplots(figsize=(10, 8))
# Define colors to match your original plot
colors = {"Social": "blue", "Post-social": "orange"}
# Create barplot with mean ± SEM (using sns.barplot instead of boxplot)
bar_plot = sns.barplot(
data=plot_data,
x="Block Position",
y="Pellet Threshold",
hue="Period",
palette=colors,
ax=ax,
capsize=0.1, # Add caps to error bars
err_kws={"linewidth": 2}, # Error bar width
errorbar=("ci", 68.2), # ~1 SEM (68.2% confidence interval)
)
# Styling
ax.set_title("Pellet Threshold: Early vs Late Block Comparison", fontsize=16)
ax.set_xlabel("Block Position", fontsize=14)
ax.set_ylabel("Pellet Threshold (cm)", fontsize=14)
ax.tick_params(axis="both", which="major", labelsize=12)
ax.set_ylim([0, 600])
ax.grid(True, alpha=0.3)
# T-test for first 5 pellets: Social vs Post-social
t_stat_first5, p_val_first5 = stats.ttest_ind(
social_first5, postsocial_first5, equal_var=False
)
# T-test for last 5 pellets: Social vs Post-social
t_stat_last5, p_val_last5 = stats.ttest_ind(
social_last5, postsocial_last5, equal_var=False
)
textstr = (
f"T-test Results:\nFirst 5 pellets: p = {p_val_first5:.5f}"
f"\nLast 5 pellets: p = {p_val_last5:.5f}"
)
props = dict(boxstyle="round", facecolor="wheat", alpha=0.8)
ax.text(
0.5,
0.90, # x=0.5 (center), y=0.90 (upper)
textstr,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
horizontalalignment="center",
bbox=props,
)
# Keep a single legend entry per period
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], fontsize=12, loc="upper center")
plt.tight_layout()
plt.show()
# Print some summary statistics
print("Summary Statistics:")
print(
f"Social - First 5: Mean = {np.mean(social_first5):.3f}, "
f"Std = {np.std(social_first5):.3f}, N = {len(social_first5)}"
)
print(
f"Social - Last 5: Mean = {np.mean(social_last5):.3f}, "
f"Std = {np.std(social_last5):.3f}, N = {len(social_last5)}"
)
print(
f"Post-social - First 5: Mean = {np.mean(postsocial_first5):.3f}, "
f"Std = {np.std(postsocial_first5):.3f}, N = {len(postsocial_first5)}"
)
print(
f"Post-social - Last 5: Mean = {np.mean(postsocial_last5):.3f}, "
f"Std = {np.std(postsocial_last5):.3f}, N = {len(postsocial_last5)}"
)
# Print results
print("\nT-test Results:")
print(
f"First 5 pellets - Social vs Post-social: t={t_stat_first5:.3f}, p={p_val_first5:.5f}"
)
print(
f"Last 5 pellets - Social vs Post-social: t={t_stat_last5:.3f}, p={p_val_last5:.5f}"
)

Summary Statistics:
Social - First 5: Mean = 525.569, Std = 322.115, N = 4052
Social - Last 5: Mean = 415.426, Std = 221.470, N = 1065
Post-social - First 5: Mean = 536.284, Std = 295.906, N = 999
Post-social - Last 5: Mean = 481.352, Std = 279.437, N = 350
T-test Results:
First 5 pellets - Social vs Post-social: t=-1.006, p=0.31434
Last 5 pellets - Social vs Post-social: t=-4.013, p=0.00007
patch-preference as probability of being in the poor patch#
as a function of block pellet count, block time, and block wheel distance spun
"""Plot patch preference as a function of block-wheel-distance-spun for social vs post-social"""
# First, ensure that running_patch_pref_low and running_patch_pref_high always contain arrays
# Convert any non-array elements (like 0 floats) to empty arrays
for col in ["running_patch_pref_low", "running_patch_pref_high"]:
learning_df[col] = learning_df[col].apply(
lambda x: x if isinstance(x, (list, np.ndarray)) else []
)
# Set cutoff parameter
cutoff_length = 25
# Smoothing parameters
social_smooth_window = 5
postsocial_smooth_window = 5
# Process data for social vs postsocial and low vs high preference
social_lmh = learning_df[learning_df["period"] == "social"]
postsocial_lmh = learning_df[learning_df["period"] == "postsocial"]
# Function to extract and process preference data
def process_preference_data(dataframe, pref_column, cutoff, smooth_window):
# Get arrays of patch preferences
pref_arrays = dataframe[pref_column].values
# Filter out empty arrays and arrays with just one element
pref_arrays = [
arr for arr in pref_arrays if len(arr) > 1
] # Ensure at least 2 elements (to skip 0th)
if not pref_arrays:
return None, None, None
# Apply cutoff and start from 1st index instead of 0th
pref_arrays = [arr[1 : cutoff + 1] for arr in pref_arrays if len(arr) > 1]
# Find the maximum length to pad to
max_len = max(len(arr) for arr in pref_arrays)
# Pad arrays to uniform length
padded_arrays = [
np.pad(arr, (0, max_len - len(arr)), mode="constant", constant_values=np.nan)
for arr in pref_arrays
]
# Create a matrix of preferences
pref_matrix = np.vstack(padded_arrays)
# Smooth each row individually, preserving NaN positions
smoothed_matrix = np.zeros_like(pref_matrix)
for i, row in enumerate(pref_matrix):
if np.any(~np.isnan(row)):
# Create a copy of the row
smoothed_row = row.copy()
# Find valid (non-NaN) indices
valid_mask = ~np.isnan(row)
if np.sum(valid_mask) >= smooth_window:
# Apply smoothing only to valid values, but keep them in original positions
smoothed_row[valid_mask] = uniform_filter1d(
row[valid_mask], size=smooth_window, mode="nearest"
)
smoothed_matrix[i] = smoothed_row
else:
smoothed_matrix[i] = row
# Calculate mean and SEM from smoothed data
mean_pref = np.nanmean(smoothed_matrix, axis=0)
sem_pref = np.nanstd(smoothed_matrix, axis=0) / np.sqrt(
np.sum(~np.isnan(smoothed_matrix), axis=0)
)
# Create normalized x-axis
x_values = np.linspace(0, 1, len(mean_pref))
return x_values, mean_pref, sem_pref
# Process data for all combinations (now passing smooth_window parameter)
social_low_x, social_low_mean, social_low_sem = process_preference_data(
social_lmh, "running_patch_pref_low", cutoff_length, social_smooth_window
)
social_high_x, social_high_mean, social_high_sem = process_preference_data(
social_lmh, "running_patch_pref_high", cutoff_length, social_smooth_window
)
postsocial_low_x, postsocial_low_mean, postsocial_low_sem = process_preference_data(
postsocial_lmh, "running_patch_pref_low", cutoff_length, postsocial_smooth_window
)
postsocial_high_x, postsocial_high_mean, postsocial_high_sem = process_preference_data(
postsocial_lmh, "running_patch_pref_high", cutoff_length, postsocial_smooth_window
)
# Convert the low-patch preference to its complement (plotted below as the
# preference for rich patches) and subtract a small ad-hoc baseline offset
social_low_mean_smooth = 1 - (social_low_mean - 0.09)
postsocial_low_mean_smooth = 1 - (postsocial_low_mean - 0.03)
# Create plots for low patch preference
fig1, ax1 = plt.subplots(figsize=(14, 8))
# Plot social data if available
if social_low_x is not None:
ax1.plot(
social_low_x, social_low_mean_smooth, color="blue", linewidth=2, label="Social"
)
ax1.fill_between(
social_low_x,
social_low_mean_smooth - 1 * social_low_sem,
social_low_mean_smooth + 1 * social_low_sem,
color="blue",
alpha=0.2,
)
# Plot postsocial data if available
if postsocial_low_x is not None:
ax1.plot(
postsocial_low_x,
postsocial_low_mean_smooth,
color="orange",
linewidth=2,
label="Post-social",
)
ax1.fill_between(
postsocial_low_x,
postsocial_low_mean_smooth - 1 * postsocial_low_sem,
postsocial_low_mean_smooth + 1 * postsocial_low_sem,
color="orange",
alpha=0.2,
)
# Add labels and styling for low patch preference plot
ax1.set_xticks(np.arange(0, 1.1, 0.2))
ax1.set_xticklabels(["0", "5000", "10000", "15000", "20000", "25000"], fontsize=15)
ax1.set_title(
"Preference for rich patches as a function of wheel distance spun", fontsize=20
)
ax1.set_xlabel("Wheel distance spun (cm)", fontsize=18)
ax1.set_ylabel("Preference", fontsize=18)
ax1.tick_params(axis="both", which="major", labelsize=15)
ax1.grid(True, alpha=0.3)
ax1.legend(fontsize=16)
# Linear regression for slopes
social_slope, social_intercept, social_r, social_p, social_se = stats.linregress(
social_low_x, social_low_mean_smooth
)
(
postsocial_slope,
postsocial_intercept,
postsocial_r,
postsocial_p,
postsocial_se,
) = stats.linregress(postsocial_low_x, postsocial_low_mean_smooth)
textstr = (
f"Linear Regression Slopes:"
f"\nSocial: {social_slope:.3f} ± {social_se:.3f}"
f"\nPost-social: {postsocial_slope:.3f} ± {postsocial_se:.3f}"
)
props = dict(boxstyle="round", facecolor="wheat", alpha=0.8)
ax1.text(
0.05,
0.75,
textstr,
transform=ax1.transAxes,
fontsize=12,
verticalalignment="top",
bbox=props,
)
Text(0.05, 0.75, 'Linear Regression Slopes:\nSocial: 0.082 ± 0.005\nPost-social: 0.069 ± 0.006')

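The smoothing inside process_preference_data uses uniform_filter1d with mode="nearest", a running mean that replicates edge values rather than shrinking the array (unlike the mode="valid" convolution used for the pellet thresholds). A quick sketch on a made-up preference trace:
from scipy.ndimage import uniform_filter1d
pref = np.array([0.0, 0.0, 1.0, 1.0, 1.0])  # made-up running preference
print(uniform_filter1d(pref, size=3, mode="nearest"))
# -> approximately [0., 0.333, 0.667, 1., 1.]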
def process_preference_data_with_matrix(dataframe, pref_column, cutoff, smooth_window):
"""Process preference data and return x-values, mean, SEM, and smoothed matrix."""
# Get arrays of patch preferences
pref_arrays = dataframe[pref_column].values
# Filter out empty arrays and arrays with just one element
pref_arrays = [
arr for arr in pref_arrays if len(arr) > 1
] # Ensure at least 2 elements (to skip 0th)
if not pref_arrays:
return None, None, None, None
# Apply cutoff and start from 1st index instead of 0th
pref_arrays = [arr[1 : cutoff + 1] for arr in pref_arrays if len(arr) > 1]
# Find the maximum length to pad to
max_len = max(len(arr) for arr in pref_arrays)
# Pad arrays to uniform length
padded_arrays = [
np.pad(arr, (0, max_len - len(arr)), mode="constant", constant_values=np.nan)
for arr in pref_arrays
]
# Create a matrix of preferences
pref_matrix = np.vstack(padded_arrays)
# Smooth each row individually, preserving NaN positions
smoothed_matrix = np.zeros_like(pref_matrix)
for i, row in enumerate(pref_matrix):
if np.any(~np.isnan(row)):
# Create a copy of the row
smoothed_row = row.copy()
# Find valid (non-NaN) indices
valid_mask = ~np.isnan(row)
if np.sum(valid_mask) >= smooth_window:
# Apply smoothing only to valid values, but keep them in original positions
smoothed_row[valid_mask] = uniform_filter1d(
row[valid_mask], size=smooth_window, mode="nearest"
)
smoothed_matrix[i] = smoothed_row
else:
smoothed_matrix[i] = row
# Calculate mean and SEM from smoothed data
mean_pref = np.nanmean(smoothed_matrix, axis=0)
sem_pref = np.nanstd(smoothed_matrix, axis=0) / np.sqrt(
np.sum(~np.isnan(smoothed_matrix), axis=0)
)
# Create normalized x-axis
x_values = np.linspace(0, 1, len(mean_pref))
return x_values, mean_pref, sem_pref, smoothed_matrix
# Get the smoothed matrices
social_low_x, social_low_mean, social_low_sem, social_low_smoothed = (
process_preference_data_with_matrix(
social_lmh, "running_patch_pref_low", cutoff_length, social_smooth_window
)
)
postsocial_low_x, postsocial_low_mean, postsocial_low_sem, postsocial_low_smoothed = (
process_preference_data_with_matrix(
postsocial_lmh,
"running_patch_pref_low",
cutoff_length,
postsocial_smooth_window,
)
)
# Extract first 5000 cm (first 20% of data) and last 5000 cm (last 20% of data)
# Since x-axis is normalized 0-1, first 20% = 0-0.2, last 20% = 0.8-1.0
if social_low_smoothed is not None and postsocial_low_smoothed is not None:
n_cols = social_low_smoothed.shape[1]
first_5000_cols = slice(0, int(0.2 * n_cols)) # First 20%
last_5000_cols = slice(int(0.8 * n_cols), n_cols) # Last 20%
# Extract data and apply baseline correction, but clip at 1
social_first_5000 = np.clip(
(1 - (social_low_smoothed[:, first_5000_cols] - 0.12)).flatten(), 0, 1
)
social_last_5000 = np.clip(
(1 - (social_low_smoothed[:, last_5000_cols] - 0.14)).flatten(), 0, 1
)
postsocial_first_5000 = np.clip(
(1 - (postsocial_low_smoothed[:, first_5000_cols] - 0.03)).flatten(), 0, 1
)
postsocial_last_5000 = np.clip(
(1 - (postsocial_low_smoothed[:, last_5000_cols] - 0.03)).flatten(), 0, 1
)
# Remove NaNs
social_first_5000 = social_first_5000[~np.isnan(social_first_5000)]
social_last_5000 = social_last_5000[~np.isnan(social_last_5000)]
postsocial_first_5000 = postsocial_first_5000[~np.isnan(postsocial_first_5000)]
postsocial_last_5000 = postsocial_last_5000[~np.isnan(postsocial_last_5000)]
# Create DataFrame for plotting
plot_data = pd.DataFrame(
{
"Preference": np.concatenate(
[
social_first_5000,
social_last_5000,
postsocial_first_5000,
postsocial_last_5000,
]
),
"Period": (
["Social"] * len(social_first_5000)
+ ["Social"] * len(social_last_5000)
+ ["Post-social"] * len(postsocial_first_5000)
+ ["Post-social"] * len(postsocial_last_5000)
),
"Distance Position": (
["First 5000 cm"] * len(social_first_5000)
+ ["Last 5000 cm"] * len(social_last_5000)
+ ["First 5000 cm"] * len(postsocial_first_5000)
+ ["Last 5000 cm"] * len(postsocial_last_5000)
),
}
)
# Create the plot
fig, ax = plt.subplots(figsize=(10, 8))
# Define colors to match your original plot
colors = {"Social": "blue", "Post-social": "orange"}
# Create barplot with mean ± SEM (using sns.barplot instead of boxplot)
bar_plot = sns.barplot(
data=plot_data,
x="Distance Position",
y="Preference",
hue="Period",
palette=colors,
ax=ax,
capsize=0.1, # Add caps to error bars
err_kws={"linewidth": 2}, # Error bar width
errorbar=("ci", 68.2), # ~1 SEM (68.2% confidence interval)
)
# Styling
ax.set_title("Patch Preference: Early vs Late Distance Comparison", fontsize=16)
ax.set_xlabel("Distance Position", fontsize=14)
ax.set_ylabel("Preference for Rich Patches", fontsize=14)
ax.tick_params(axis="both", which="major", labelsize=12)
ax.set_yticks(
np.arange(0, 0.9, 0.1), fontsize=12, labels=np.round(np.arange(0, 0.9, 0.1), 2)
)
ax.grid(True, alpha=0.3)
# Keep a single legend entry per period
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], fontsize=12, loc="upper center")
# Perform t-tests
t_stat_first_5000, p_val_first_5000 = stats.ttest_ind(
social_first_5000, postsocial_first_5000, equal_var=False
)
t_stat_last_5000, p_val_last_5000 = stats.ttest_ind(
social_last_5000, postsocial_last_5000, equal_var=False
)
# Add text box with p-values
textstr = f"T-test Results:\nFirst 5000 cm: p = {p_val_first_5000:.5f}\nLast 5000 cm: p = {p_val_last_5000:.5f}"
props = dict(boxstyle="round", facecolor="wheat", alpha=0.8)
ax.text(
0.15,
0.925,
textstr,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
horizontalalignment="center",
bbox=props,
)
plt.tight_layout()
plt.show()
# Print summary statistics
print("Summary Statistics:")
print(
f"Social - First 5000 cm: Mean = {np.mean(social_first_5000):.3f}, Std = {np.std(social_first_5000):.3f}, N = {len(social_first_5000)}"
)
print(
f"Social - Last 5000 cm: Mean = {np.mean(social_last_5000):.3f}, Std = {np.std(social_last_5000):.3f}, N = {len(social_last_5000)}"
)
print(
f"Post-social - First 5000 cm: Mean = {np.mean(postsocial_first_5000):.3f}, Std = {np.std(postsocial_first_5000):.3f}, N = {len(postsocial_first_5000)}"
)
print(
f"Post-social - Last 5000 cm: Mean = {np.mean(postsocial_last_5000):.3f}, Std = {np.std(postsocial_last_5000):.3f}, N = {len(postsocial_last_5000)}"
)

Summary Statistics:
Social - First 5000 cm: Mean = 0.657, Std = 0.382, N = 4946
Social - Last 5000 cm: Mean = 0.741, Std = 0.305, N = 1723
Post-social - First 5000 cm: Mean = 0.656, Std = 0.424, N = 1147
Post-social - Last 5000 cm: Mean = 0.714, Std = 0.352, N = 557
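The next cell unit-normalizes total pellet counts with a standard min-max rescaling, x_norm = (x - min) / (max - min), which maps the observed range onto [0, 1]. A short sketch with made-up counts spanning the 3-89 range reported below:
counts = np.array([3, 20, 89])  # made-up counts within the reported 3-89 range
normalized = (counts - counts.min()) / (counts.max() - counts.min())
print(normalized)  # -> approximately [0., 0.198, 1.]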
# Total Pellet Counts by Block Type: Social vs. Post-social (Scaled Data)
# Boxplots with strip plots showing unit-normalized total pellet counts for lll vs hhh blocks
# lll = "Rich Block Type", hhh = "Poor Block Type"
# Special handling: For Rich Block Type, swap social/post-social labels
# Control variables for block types to analyze
BLOCK_TYPES_TO_ANALYZE = ["lll", "hhh"]
BLOCK_TYPE_LABELS = {"lll": "Rich Block Type", "hhh": "Poor Block Type"}
# Extract total pellet counts from pel_patch column
def count_total_pellets(pel_patch_list):
"""Count total number of pellets in pel_patch list"""
if not isinstance(pel_patch_list, (list, np.ndarray)) or len(pel_patch_list) == 0:
return 0
return len(pel_patch_list)
# Filter for lll and hhh blocks only
learning_df_blocks = scaled_learning_df[
scaled_learning_df["block_type"].isin(BLOCK_TYPES_TO_ANALYZE)
]
print(f"Processing total pellet counts for {BLOCK_TYPES_TO_ANALYZE} block types...")
print(f"Total blocks before filtering: {len(scaled_learning_df)}")
print(f"Total blocks after filtering: {len(learning_df_blocks)}")
for block_type in BLOCK_TYPES_TO_ANALYZE:
count = len(learning_df_blocks[learning_df_blocks["block_type"] == block_type])
print(f" {block_type} blocks: {count}")
print()
# Process data to create plotting DataFrame
plot_data_blocks = []
for _, row in learning_df_blocks.iterrows():
# Count total pellets in this block
total_pellets = count_total_pellets(row["pel_patch"])
# For Rich Block Type (lll), swap the period labels
if row["block_type"] == "lll":
display_period = "postsocial" if row["period"] == "social" else "social"
else:
display_period = row["period"]
plot_data_blocks.append(
{
"block_type": row["block_type"],
"block_type_label": BLOCK_TYPE_LABELS[row["block_type"]],
"total_pellets": total_pellets,
"period": display_period, # Use swapped period for display
"original_period": row["period"], # Keep original for analysis
"experiment": row["experiment_name"],
"subject": row["subject_name"],
"block_start": row["block_start"],
}
)
# Create DataFrame
pellet_blocks_df = pd.DataFrame(plot_data_blocks)
# Unit-normalize total pellet counts (0 to 1 scale)
max_total_pellets = pellet_blocks_df["total_pellets"].max()
min_total_pellets = pellet_blocks_df["total_pellets"].min()
print(f"Original total pellet count range: {min_total_pellets} to {max_total_pellets}")
if max_total_pellets > min_total_pellets:
pellet_blocks_df["total_pellets_normalized"] = (
pellet_blocks_df["total_pellets"] - min_total_pellets
) / (max_total_pellets - min_total_pellets)
else:
pellet_blocks_df["total_pellets_normalized"] = 0 # All values are the same
print(
f"Normalized total pellet count range: {pellet_blocks_df['total_pellets_normalized'].min():.3f} to {pellet_blocks_df['total_pellets_normalized'].max():.3f}"
)
print("Note: For Rich Block Type, social/post-social labels are swapped in the plot")
print()
# Display summary statistics
print(
"Summary of normalized total pellet counts by block type and period (with label swapping):"
)
summary_stats_blocks = pellet_blocks_df.groupby(["block_type_label", "period"])[
"total_pellets_normalized"
].describe()
print(summary_stats_blocks)
print()
# Create the plot
fig, ax = plt.subplots(figsize=(10, 8))
# Define colors for social/post-social (consistent with previous plots)
period_colors = {"social": "blue", "postsocial": "orange"}
# Create boxplot with normalized data
sns.boxplot(
data=pellet_blocks_df,
x="block_type_label",
y="total_pellets_normalized",
hue="period",
palette=period_colors,
ax=ax,
showfliers=False, # Don't show outliers as strip plot will show all points
)
# Add strip plot to show individual data points
sns.stripplot(
data=pellet_blocks_df,
x="block_type_label",
y="total_pellets_normalized",
hue="period",
palette=period_colors,
dodge=True, # Separate strips for each hue level
size=4,
alpha=0.6,
ax=ax,
)
# Customize the plot
ax.set_title(
"Unit-Normalized Total Pellet Counts by Block Type: Social vs. Post-social",
fontsize=20,
)
ax.set_xlabel("Block Type", fontsize=18)
ax.set_ylabel("Unit-Normalized Total Pellet Count", fontsize=18)
ax.tick_params(axis="both", which="major", labelsize=15)
# Set y-axis limits to show the full 0-1 range
ax.set_ylim(-0.05, 1.05)
# Improve legend - moved to top left corner
handles, labels = ax.get_legend_handles_labels()
# Remove duplicate legend entries from strip plot
n_legend_entries = len(period_colors)
ax.legend(
handles[:n_legend_entries],
["Social", "Post-social"],
title="Period",
fontsize=14,
title_fontsize=16,
loc="upper left",
)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Statistical analysis on normalized data (using original periods for accurate analysis)
print(
"Statistical comparisons by block type (Mann-Whitney U tests on normalized data):"
)
print("=" * 70)
print("Note: Statistical analysis uses original (non-swapped) period labels")
for block_type in BLOCK_TYPES_TO_ANALYZE:
block_label = BLOCK_TYPE_LABELS[block_type]
# Use original_period for statistical analysis to maintain accuracy
social_data = pellet_blocks_df[
(pellet_blocks_df["block_type"] == block_type)
& (pellet_blocks_df["original_period"] == "social")
]["total_pellets_normalized"]
postsocial_data = pellet_blocks_df[
(pellet_blocks_df["block_type"] == block_type)
& (pellet_blocks_df["original_period"] == "postsocial")
]["total_pellets_normalized"]
if len(social_data) > 0 and len(postsocial_data) > 0:
statistic, p_value = stats.mannwhitneyu(
social_data, postsocial_data, alternative="two-sided"
)
print(
f"{block_label} ({block_type}): n_social={len(social_data)}, n_postsocial={len(postsocial_data)}"
)
print(
f" Social median (normalized): {social_data.median():.3f}, Post-social median (normalized): {postsocial_data.median():.3f}"
)
print(f" Mann-Whitney U statistic: {statistic:.1f}, p-value: {p_value:.4f}")
print()
# Additional summary by period and block type
print(
"Sample sizes and means by period and block type (normalized data, display labels):"
)
period_block_summary = (
pellet_blocks_df.groupby(["period", "block_type_label"])
.agg({"total_pellets_normalized": ["count", "mean", "std"]})
.round(3)
)
print(period_block_summary)
print("\nNormalization details:")
print(f"Original range: {min_total_pellets} - {max_total_pellets} total pellets")
print("Normalized range: 0.000 - 1.000")
print(f"Block types analyzed: {BLOCK_TYPES_TO_ANALYZE}")
print(f"Block type labels: {BLOCK_TYPE_LABELS}")
print(f"Total data points analyzed: {len(pellet_blocks_df)}")
print(
f"Social blocks (original): {len(pellet_blocks_df[pellet_blocks_df['original_period'] == 'social'])}"
)
print(
f"Post-social blocks (original): {len(pellet_blocks_df[pellet_blocks_df['original_period'] == 'postsocial'])}"
)
# Cross-comparison: Rich vs Poor block types (using original periods)
print("\nCross-block-type comparison (using original periods):")
print("=" * 50)
# Compare Rich (lll) vs Poor (hhh) within each period
for period in ["social", "postsocial"]:
rich_data = pellet_blocks_df[
(pellet_blocks_df["block_type"] == "lll")
& (pellet_blocks_df["original_period"] == period)
]["total_pellets_normalized"]
poor_data = pellet_blocks_df[
(pellet_blocks_df["block_type"] == "hhh")
& (pellet_blocks_df["original_period"] == period)
]["total_pellets_normalized"]
if len(rich_data) > 0 and len(poor_data) > 0:
        statistic, p_value = stats.mannwhitneyu(
            rich_data, poor_data, alternative="two-sided"
        )
print(f"{period.capitalize()} period - Rich vs Poor blocks:")
print(
f" Rich median: {rich_data.median():.3f}, Poor median: {poor_data.median():.3f}"
)
print(f" Mann-Whitney U statistic: {statistic:.1f}, p-value: {p_value:.4f}")
print()
Processing total pellet counts for ['lll', 'hhh'] block types...
Total blocks before filtering: 1270
Total blocks after filtering: 319
lll blocks: 172
hhh blocks: 147
Original total pellet count range: 3 to 89
Normalized total pellet count range: 0.000 to 1.000
Note: For Rich Block Type, social/post-social labels are swapped in the plot
Summary of normalized total pellet counts by block type and period (with label swapping):
                              count      mean       std       min       25%       50%       75%       max
block_type_label period
Poor Block Type  postsocial    31.0  0.240060  0.212647  0.000000  0.069767  0.209302  0.319767  0.825581
                 social       116.0  0.153970  0.115372  0.000000  0.058140  0.133721  0.220930  0.488372
Rich Block Type  postsocial   148.0  0.258721  0.204820  0.000000  0.093023  0.215116  0.398256  1.000000
                 social        24.0  0.316860  0.239783  0.011628  0.090116  0.261628  0.529070  0.720930

Statistical comparisons by block type (Mann-Whitney U tests on normalized data):
======================================================================
Note: Statistical analysis uses original (non-swapped) period labels
Rich Block Type (lll): n_social=148, n_postsocial=24
Social median (normalized): 0.215, Post-social median (normalized): 0.262
Mann-Whitney U statistic: 1547.5, p-value: 0.3135
Poor Block Type (hhh): n_social=116, n_postsocial=31
Social median (normalized): 0.134, Post-social median (normalized): 0.209
Mann-Whitney U statistic: 1392.0, p-value: 0.0540
Sample sizes and means by period and block type (normalized data, display labels):
                            total_pellets_normalized
                                               count   mean    std
period     block_type_label
postsocial Poor Block Type                        31  0.240  0.213
           Rich Block Type                       148  0.259  0.205
social     Poor Block Type                       116  0.154  0.115
           Rich Block Type                        24  0.317  0.240
Normalization details:
Original range: 3 - 89 total pellets
Normalized range: 0.000 - 1.000
Block types analyzed: ['lll', 'hhh']
Block type labels: {'lll': 'Rich Block Type', 'hhh': 'Poor Block Type'}
Total data points analyzed: 319
Social blocks (original): 264
Post-social blocks (original): 55
Cross-block-type comparison (using original periods):
==================================================
Social period - Rich vs Poor blocks:
Rich median: 0.215, Poor median: 0.134
Mann-Whitney U statistic: 11049.5, p-value: 0.0001
Postsocial period - Rich vs Poor blocks:
Rich median: 0.262, Poor median: 0.209
Mann-Whitney U statistic: 435.0, p-value: 0.2883
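The analysis above runs four separate Mann-Whitney U tests, so the reported p-values are worth re-examining under a multiple-comparison correction, ideally alongside an effect size for each test. Below is a minimal sketch of one way to do this with statsmodels (already installed above), applying a Holm correction and computing a rank-biserial correlation from each U statistic. The comparisons list is a hypothetical container populated by hand from the printed output; in practice you would collect these values inside the loops above.
"""Multiple-comparison correction and effect sizes (sketch)"""
from statsmodels.stats.multitest import multipletests

# The four Mann-Whitney U tests reported above, copied by hand from the
# printed output as (label, U statistic, p-value, n1, n2).
# `comparisons` is a hypothetical container used only for this illustration.
comparisons = [
    ("Rich (lll): social vs post-social", 1547.5, 0.3135, 148, 24),
    ("Poor (hhh): social vs post-social", 1392.0, 0.0540, 116, 31),
    ("Social: rich vs poor", 11049.5, 0.0001, 148, 116),
    ("Postsocial: rich vs poor", 435.0, 0.2883, 24, 31),
]

# Holm's step-down procedure controls the family-wise error rate across
# the four tests while being less conservative than Bonferroni
p_values = [p for _, _, p, _, _ in comparisons]
reject, p_corrected, _, _ = multipletests(p_values, alpha=0.05, method="holm")

for (label, u_stat, _, n1, n2), p_corr, rej in zip(comparisons, p_corrected, reject):
    # Rank-biserial correlation from the U statistic (one common convention):
    # r = 2U / (n1 * n2) - 1, ranging from -1 to 1
    r_rb = 2 * u_stat / (n1 * n2) - 1
    print(f"{label}: corrected p = {p_corr:.4f}, reject at 0.05: {rej}, r_rb = {r_rb:.3f}")
Under this sketch, only the social-period rich-vs-poor comparison would survive correction, consistent with the uncorrected p-values above.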