{ "cells": [ { "cell_type": "markdown", "id": "2cbbd778", "metadata": {}, "source": [ "# Hidden Markov Model analysis of mouse behavioural syntax\n", "\n", "In this example, we fit and apply a Hidden Markov Model (HMM) to infer behavioural syllables from pre-processed mouse tracking data.\n", "\n", "The [sample dataset](../../downloads/hmm_example_mouse_pos.pkl) is a two-hour snippet of a single mouse in a foraging assay. \n", "The mouse was tracked using [SLEAP](sleap:) with key body parts annotated as follows:\n", "\n", ":::{image} ../../images/hmm-example-mouse-body-parts.png\n", ":alt: Mouse body part annotations\n", ":width: 50%\n", ":align: center\n", ":::\n", "\n", "The data includes:\n", "- raw centroid positions (`x`, `y`)\n", "- Kalman-filtered estimates of centroid positions, speed, and acceleration (`smoothed_x`, `smoothed_y`, `smoothed_speed`, `smoothed_acceleration`)\n", "- pairwise distances between key body parts (`head-spine3`, `left_ear-spine3`, `right_ear-spine3`, `spine1-spine3`)\n", "\n", "These pairwise distances were selected based on their contribution to overall variance in body shape (i.e. length and curvature), as determined by applying Singular Value Decomposition (SVD) to a standardised distance matrix. " ] }, { "cell_type": "markdown", "id": "3389234a", "metadata": {}, "source": [ "## Set up environment\n", "\n", "Create and activate a virtual environment named `hmm-example` using [uv](https://docs.astral.sh/uv/getting-started/installation/).\n", "```bash\n", "uv venv hmm-example --python \">=3.11\" \n", "source hmm-example/bin/activate \n", "```\n", "\n", "Install the required [`ssm` package](https://github.com/lindermanlab/ssm) and its dependencies.\n", "```bash\n", "uv pip install setuptools wheel numpy cython && uv pip install --no-build-isolation \"git+https://github.com/lindermanlab/ssm#egg=ssm[plotting]\"\n", "```\n", "\n", "## Import libraries and define helper class" ] }, { "cell_type": "code", "execution_count": null, "id": "7ac920ac", "metadata": {}, "outputs": [], "source": [ "import autograd.numpy as np\n", "import autograd.numpy.random as npr\n", "\n", "npr.seed(42)\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import pickle\n", "\n", "import seaborn as sns\n", "\n", "import ssm" ] }, { "cell_type": "code", "execution_count": 2, "id": "dd87407b", "metadata": {}, "outputs": [], "source": [ "class AeonHMM:\n", " \"\"\"A class for training and analysing Hidden Markov Models (HMM) using the `ssm` library.\"\"\"\n", "\n", " def __init__(self, n_state):\n", " \"\"\"Initialise AeonHMM with the number of hidden states.\"\"\"\n", " self.n_state = n_state # Number of hidden states\n", " self.features = [\n", " \"smoothed_speed\",\n", " \"smoothed_acceleration\",\n", " \"head-spine3\",\n", " \"left_ear-spine3\",\n", " \"right_ear-spine3\",\n", " \"spine1-spine3\",\n", " ] # Expected features in the input data\n", " self.model = None # HMM model instance\n", " self.parameters = None # Sorted model parameters (mean, variance, covariance)\n", " self.transition_mat = None # Sorted transition matrix\n", " self.states = None # Inferred states\n", " self.connectivity_mat = None # Connectivity matrix\n", " self.test_lls = None # Log-likelihoods of the test data\n", " self.train_lls = None # Log-likelihoods of the training data\n", "\n", " def get_connectivity_matrix(self):\n", " \"\"\"Compute the normalised connectivity matrix from the inferred states.\"\"\"\n", " connectivity_mat = np.zeros((self.n_state, self.n_state))\n", " states = 
self.states\n", " # Count transitions between states\n", " for i in range(len(states) - 1):\n", " if states[i + 1] != states[i]:\n", " connectivity_mat[states[i]][states[i + 1]] += 1\n", " # Normalise to sum to 1\n", " for i in range(self.n_state):\n", " total = np.sum(connectivity_mat[i])\n", " if total > 0:\n", " connectivity_mat[i] /= total\n", "\n", " return connectivity_mat\n", "\n", " def fit_model(self, train_data, num_iters=50):\n", " \"\"\"Fit the HMM model to the training data using the EM algorithm.\"\"\"\n", " fitting_input = np.array(train_data)\n", " self.model = ssm.HMM(\n", " self.n_state, len(fitting_input[0]), observations=\"gaussian\"\n", " )\n", " lls = self.model.fit(\n", " fitting_input, method=\"em\", num_iters=num_iters, init_method=\"kmeans\"\n", " )\n", " self.train_lls = lls\n", "\n", " def infer_states(self, test_data):\n", " \"\"\"Infer states for the test data.\"\"\"\n", " obs = np.array(test_data)\n", " self.test_lls = self.model.log_likelihood(obs)\n", " self.states = self.model.most_likely_states(obs)\n", "\n", " def sort(self, sort_idx):\n", " \"\"\"Sort the model parameters, transition matrix, and inferred states based on the provided indices.\"\"\"\n", " # Sort Gaussian means: shape (n_features, n_state)\n", " parameters_mean_sorted = self.model.observations.params[0][sort_idx].T\n", " # Extract and sort variances: shape (n_features, n_state)\n", " parameters_var = np.zeros((self.n_state, len(self.features)))\n", " for i in range(self.n_state):\n", " for j in range(len(self.features)):\n", " # state i, feature j\n", " parameters_var[i, j] = self.model.observations.params[1][i][j][j]\n", " parameters_var_sorted = parameters_var[sort_idx].T\n", " # Sort covariance matrices: shape (n_state, n_features, n_features)\n", " parameters_covar_sorted = self.model.observations.params[1][sort_idx]\n", " self.parameters = [\n", " parameters_mean_sorted,\n", " parameters_var_sorted,\n", " parameters_covar_sorted,\n", " ]\n", " # Sort transition matrix: shape (n_state, n_state)\n", " self.transition_mat = (\n", " self.model.transitions.transition_matrix[sort_idx].T[sort_idx].T\n", " )\n", " # Compute connectivity matrix\n", " self.connectivity_mat = self.get_connectivity_matrix()\n", " # Reassign state labels to reflect new order\n", " new_values = np.empty_like(self.states)\n", " for i, val in enumerate(sort_idx):\n", " new_values[self.states == val] = i\n", " self.states = new_values" ] }, { "cell_type": "markdown", "id": "e3ab438d", "metadata": {}, "source": [ "## Load sample data\n", "The sample dataset can be downloaded [here](../../downloads/hmm_example_mouse_pos.pkl). \n", "Please change the value of `file_path` to the location where you saved the file." ] }, { "cell_type": "code", "execution_count": 130, "id": "bc7c1d0d", "metadata": {}, "outputs": [], "source": [ "file_path = \"/path/to/hmm_example_mouse_pos.pkl\"" ] }, { "cell_type": "code", "execution_count": 4, "id": "a1faf0c5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | smoothed_speed | \n", "smoothed_acceleration | \n", "head-spine3 | \n", "spine1-spine3 | \n", "left_ear-spine3 | \n", "right_ear-spine3 | \n", "
---|---|---|---|---|---|---|
2024-02-01 07:00:00.080 | \n", "6.765046 | \n", "19.325969 | \n", "17.767353 | \n", "12.744590 | \n", "17.301722 | \n", "15.257318 | \n", "
2024-02-01 07:00:00.180 | \n", "8.032855 | \n", "13.389354 | \n", "17.269546 | \n", "12.521794 | \n", "18.267784 | \n", "15.036750 | \n", "
2024-02-01 07:00:00.280 | \n", "8.410684 | \n", "7.612856 | \n", "18.342884 | \n", "12.943109 | \n", "18.868068 | \n", "15.568200 | \n", "
2024-02-01 07:00:00.380 | \n", "7.863123 | \n", "9.032341 | \n", "18.010769 | \n", "12.811849 | \n", "18.209636 | \n", "14.367178 | \n", "
2024-02-01 07:00:00.480 | \n", "6.422148 | \n", "15.558013 | \n", "18.457386 | \n", "12.783098 | \n", "18.527658 | \n", "15.594391 | \n", "
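The four pairwise-distance columns shown above were fixed during pre-processing, as described in the introduction: candidate distances were ranked by their contribution to overall variance in body shape using SVD of a standardised distance matrix. That selection step is not part of this notebook, but a minimal sketch of how it could look is given below; the DataFrame `all_distances` (one column per candidate body-part pair) and the helper name `rank_distance_features` are illustrative assumptions, not names from the original pipeline.

```python
import numpy as np
import pandas as pd


def rank_distance_features(all_distances: pd.DataFrame, n_components: int = 2) -> pd.Series:
    """Rank candidate pairwise distances by their loading on the leading SVD components.

    `all_distances` is a hypothetical DataFrame with one column per body-part pair
    (e.g. "head-spine3", "spine1-spine2", ...), sampled over time.
    """
    # Standardise each column (zero mean, unit variance) so no single distance dominates
    standardised = (all_distances - all_distances.mean()) / all_distances.std()
    # SVD of the standardised matrix: rows are time points, columns are distances
    _, singular_values, vt = np.linalg.svd(standardised.to_numpy(), full_matrices=False)
    # Weight each component's loadings by its singular value and sum their magnitudes
    loadings = np.abs(vt[:n_components].T * singular_values[:n_components]).sum(axis=1)
    # Distances with the largest loadings explain most of the variance in body shape
    return pd.Series(loadings, index=all_distances.columns).sort_values(ascending=False)
```

Distances near the top of such a ranking capture most of the variation in body length and curvature, which is the selection criterion described in the introduction.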