Last active
November 17, 2025 21:53
-
-
Save matham/2a499bbba251117287857da0aa6c5aeb to your computer and use it in GitHub Desktop.
Export results for teaball experiments: sniffing, occupancy, etc.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass, field | |
| import pprint | |
| from collections.abc import Sequence | |
| from typing import Callable, Union, Literal, Any, Optional | |
| from pathlib import Path | |
| import csv | |
| from scipy.signal import decimate | |
| from tempfile import TemporaryDirectory | |
| import subprocess | |
| import math | |
| import tqdm | |
| import numpy as np | |
| from functools import partial | |
| from itertools import product | |
| import matplotlib.pyplot as plt | |
| from mpl_toolkits.axes_grid1 import make_axes_locatable | |
| from scipy.ndimage import gaussian_filter | |
| import json | |
# Measures that convert_to_categoricals can map onto a categorical label per sample
# (the side of the box for "occupancy", freezing vs. not for the "*_freezing" ones).
CAT_MEASURE_TYPE = Literal[
    "occupancy",
    "motion_index_freezing",
    "speed_freezing",
]
# Measures accepted by export_bouts, which forwards them to SubjectPos.calculate_measure_bout.
BOUT_MEASURE_TYPE = Literal[
    "motion_index_freezing",
    "speed_freezing",
    "motion_index_dart",
    "speed_dart",
]
def save_or_show(save_fig_root: None | Path = None, save_fig_prefix: str = "", width_inch: int = 8,
                 height_inch: int = 6, dpi: int = 300):
    """Save the current matplotlib figure as a PNG, or show it interactively.

    :param save_fig_root: Output directory (created if needed). When falsy, the figure is displayed
        with ``plt.show()`` instead of being saved.
    :param save_fig_prefix: File stem; the figure is written to ``save_fig_root / f"{save_fig_prefix}.png"``.
    :param width_inch: Figure width applied before saving (not used when showing).
    :param height_inch: Figure height applied before saving (not used when showing).
    :param dpi: Resolution of the saved PNG.
    """
    if not save_fig_root:
        plt.tight_layout()
        plt.show()
        return

    save_fig_root.mkdir(parents=True, exist_ok=True)
    fig = plt.gcf()
    fig.set_size_inches(width_inch, height_inch)
    fig.tight_layout()
    fig.savefig(save_fig_root / f"{save_fig_prefix}.png", bbox_inches="tight", dpi=dpi)
    # close exactly the figure that was saved so batch exports don't accumulate open figures
    plt.close(fig)
def get_log_bins(data: np.ndarray, n_bins: int) -> np.ndarray:
    """Return ``n_bins`` logarithmically spaced histogram edges covering ``data``.

    The edges span whole decades, from ``10 ** floor(log10(min(data)))`` to
    ``10 ** ceil(log10(max(data)))``. Note that ``n_bins`` edges delimit ``n_bins - 1`` bins.

    :param data: Strictly positive values to be histogrammed.
    :param n_bins: Number of edges to return.
    :raises ValueError: If ``data`` is empty or contains a non-positive value (log scale undefined).
    """
    data = np.asarray(data)
    low = np.min(data)
    high = np.max(data)
    if low <= 0:
        raise ValueError(f"Log bins require strictly positive data, got a minimum of {low}")
    start = np.floor(np.log10(low))
    stop = np.ceil(np.log10(high))
    if stop <= start:
        # every value lies on the same power of ten (e.g. all 1.0); widen to one decade so the
        # edges strictly increase instead of all being equal
        stop = start + 1
    return np.logspace(start, stop, n_bins)
@dataclass(eq=False)
class Experiment:
    """One recorded subject session split into pre-trial (habituation), trial and post-trial periods.

    Period boundaries are in seconds (the CSV inventory readers convert "M:SS" strings to seconds).
    The position track is loaded later by ``read_pos_track``. The box geometry is loaded later by
    ``parse_box_metadata`` or ``set_box_metadata``, optionally followed by ``enlarge_canvas`` to align
    several experiments on one canvas.
    """

    # period boundaries in seconds, passed to SubjectPos.extract_range
    # NOTE(review): presumably relative to the start of the position track — confirm
    pre_start: float
    pre_end: float
    trial_start: float
    trial_end: float
    post_start: float
    post_end: float
    # format string expanded with ``metadata`` to locate the position-track CSV (see pos_path_filename)
    filename_fmt: str
    # label used as the prefix of plot titles
    title_root: str
    # identifying columns of this experiment (e.g. subject, date), also written to exported CSVs
    metadata: dict[str, Any] = field(default_factory=dict)
    # display names of the three periods, in pre / trial / post order
    triplet_name: tuple[str, str, str] = "Habituation", "Trial", "Post Trial"
    # pixel-to-cm scale; None keeps measurements in pixels
    cm_per_px: float | None = None
    # full position track and its per-period slices, populated by read_pos_track
    pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    pre_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    trial_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    post_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    # sub-segments populated by chop_pos_track, keyed by period name (plus segment start if several)
    chopped_pos_data: dict[str, "SubjectPos"] = field(default_factory=dict, init=False, repr=False)
    # (x, y) center and (width, height) of the arena box, in image pixels
    box_center: tuple[int, int] | None = field(default=None, init=False)
    box_size: tuple[int, int] | None = field(default=None, init=False)
    # (width, height) of the original video frame
    image_size: tuple[int, int] | None = field(default=None, init=False)
    # how much we need to offset our box center so our box center aligns with the mean box center of all boxes
    image_offset: tuple[int, int] | None = field(default=None, init=False)
    # size of the largest image/canvas we need to fit all experiments, including their offsets
    enlarged_image_size: tuple[int, int] | None = field(default=None, init=False, repr=False)
    # size of the largest box of all the experiments
    largest_box_size: tuple[int, int] | None = field(default=None, init=False, repr=False)
    # point of interest (e.g. the cue), either coordinates or descriptors such as "left"/"right"
    # resolved by _descriptor_to_point
    cue_point: tuple[float | str, float | str] | None = field(default=None, init=False, repr=False)
| def get_header_id_columns(self) -> list[str]: | |
| return list(self.metadata.keys()) | |
| def get_id_columns(self) -> list: | |
| return list(self.metadata.values()) | |
| def pos_path_filename(self, data_root: Path) -> Path: | |
| return data_root / self.filename_fmt.format(**self.metadata) | |
| def box_metadata_filename(self, data_root: Path) -> Path: | |
| return data_root / "{date}_{subject}_top_000000000.json".format(**self.metadata) | |
| @classmethod | |
| def read_experiment_from_csv_inventory( | |
| cls, csv_filename: Path | str, metadata: list[str], filename_fmt: str, | |
| title_root_fmt: str = "#{subject} ({date})", **selectors: str | |
| ) -> Optional["Experiment"]: | |
| experiment = None | |
| with open(csv_filename, "r") as fh: | |
| reader = csv.reader(fh) | |
| header = next(reader) | |
| for row in reader: | |
| metadata_values = {m: row[i] for i, m in enumerate(metadata)} | |
| if not all(metadata_values.get(k) == v for k, v in selectors.items()): | |
| continue | |
| if experiment is not None: | |
| raise ValueError(f"Found more than one row matching {selectors}") | |
| times = [] | |
| for t in row[len(metadata):len(metadata) + 6]: | |
| if ":" in t: | |
| min, sec = map(int, t.split(":")) | |
| times.append(min * 60 + sec) | |
| else: | |
| times.append(int(t)) | |
| experiment = cls( | |
| metadata=metadata_values, filename_fmt=filename_fmt, | |
| pre_start=times[0], pre_end=times[1], | |
| trial_start=times[2], trial_end=times[3], | |
| post_start=times[4], post_end=times[5], title_root=title_root_fmt.format(**metadata_values), | |
| ) | |
| if experiment is None: | |
| raise ValueError(f"Didn't find any row matching {selectors}") | |
| return experiment | |
| @classmethod | |
| def read_experiments_from_csv_inventory( | |
| cls, filename: Path, metadata: list[str], filename_fmt: str, | |
| title_root_fmt: str = "#{subject} ({date})", | |
| ) -> list["Experiment"]: | |
| experiments = [] | |
| with open(filename, "r") as fh: | |
| reader = csv.reader(fh) | |
| header = next(reader) | |
| for row in reader: | |
| metadata_values = {m: row[i] for i, m in enumerate(metadata)} | |
| times = [] | |
| for t in row[len(metadata):len(metadata) + 6]: | |
| if ":" in t: | |
| min, sec = map(int, t.split(":")) | |
| times.append(min * 60 + sec) | |
| else: | |
| times.append(int(t)) | |
| experiment = cls( | |
| metadata=metadata_values, filename_fmt=filename_fmt, | |
| pre_start=times[0], pre_end=times[1], | |
| trial_start=times[2], trial_end=times[3], | |
| post_start=times[4], post_end=times[5], title_root=title_root_fmt.format(**metadata_values), | |
| ) | |
| experiments.append(experiment) | |
| return experiments | |
| def parse_box_metadata( | |
| self, data_root: Path, box_names: tuple[str] = ("farside", "nearside"), | |
| cm_per_px: float | None = None, | |
| ): | |
| filename = self.box_metadata_filename(data_root) | |
| with open(filename, "r") as fh: | |
| data = json.load(fh) | |
| shape = None | |
| for s in data["shapes"]: | |
| for name in box_names: | |
| if s["label"] == name: | |
| shape = s | |
| if shape is None: | |
| raise ValueError(f"Cannot find {box_names} in the json file. {filename}") | |
| points = np.array(shape["points"]) | |
| min_x, min_y = np.min(points, axis=0) | |
| max_x, max_y = np.max(points, axis=0) | |
| w = max_x - min_x | |
| h = max_y - min_y | |
| self.box_center = int(min_x + w / 2), int(min_y + h / 2) | |
| self.box_size = int(w), int(h) | |
| self.image_size = int(data["imageWidth"]), int(data["imageHeight"]) | |
| self.cm_per_px = cm_per_px | |
| def set_box_metadata( | |
| self, box_left: int, box_top: int, box_right: int, box_bottom: int, image_size: tuple[int, int], | |
| cm_per_px: float | None = None, | |
| ): | |
| self.box_center = int((box_right + box_left) / 2), int((box_bottom + box_top) / 2) | |
| self.box_size = box_right - box_left, box_bottom - box_top | |
| self.image_size = image_size | |
| self.image_offset = 0, 0 | |
| self.enlarged_image_size = self.image_size | |
| self.largest_box_size = self.box_size | |
| self.cm_per_px = cm_per_px | |
| @classmethod | |
| def enlarge_canvas(cls, experiments: list["Experiment"]) -> tuple[int, int]: | |
| box_centers = np.array([e.box_center for e in experiments]) | |
| box_sizes = np.array([e.box_size for e in experiments]) | |
| image_sizes = np.array([e.image_size for e in experiments]) | |
| max_image = np.max(image_sizes, axis=0) | |
| image_centers = np.floor(image_sizes / 2) | |
| adjusted_image_centers = np.floor(image_centers + (max_image[None, :] - image_sizes) / 2) | |
| adjusted_image_offsets = adjusted_image_centers - image_centers | |
| adjusted_box_centers = box_centers + adjusted_image_offsets | |
| mean_adjusted_box_centers = np.floor(np.mean(adjusted_box_centers, axis=0)) | |
| aligned_box_offsets = mean_adjusted_box_centers - adjusted_box_centers | |
| # how much we need to offset our image data so our box center aligns with the mean box center of all boxes | |
| total_image_offset = adjusted_image_offsets + aligned_box_offsets | |
| min_offset = np.min(total_image_offset, axis=0) | |
| max_offset = np.max(total_image_offset, axis=0) | |
| final_image_size = max_image + max_offset - min_offset | |
| max_box_size = np.max(box_sizes, axis=0) | |
| for i, experiment in enumerate(experiments): | |
| experiment.image_offset = tuple(map(int, total_image_offset[i, :])) | |
| experiment.enlarged_image_size = tuple(map(int, final_image_size)) | |
| experiment.largest_box_size = tuple(map(int, max_box_size)) | |
| return tuple(map(int, final_image_size)) | |
| def position_to_side( | |
| self, x: int, y: int, split_horizontally: bool = True, | |
| categoricals: tuple[str, ...] = ("Near-side", "Far-side"), | |
| ): | |
| cx, cy = self.box_center | |
| cue_x, cue_y = self._descriptor_to_point(self.cue_point) | |
| if split_horizontally: | |
| if cue_x < cx: | |
| i = 0 if x < cx else 1 | |
| else: | |
| i = 0 if x >= cx else 1 | |
| else: | |
| if cue_y < cy: | |
| i = 0 if y < cy else 1 | |
| else: | |
| i = 0 if y >= cy else 1 | |
| return categoricals[i] | |
| def convert_to_categoricals( | |
| self, data_group: Union["SubjectPos", str], measure: CAT_MEASURE_TYPE, | |
| categoricals: tuple[str, ...] = ("Near-side", "Far-side"), | |
| measure_options: dict | None = None, | |
| ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| if measure_options is None: | |
| measure_options = {} | |
| if isinstance(data_group, str): | |
| data_group = getattr(self, data_group) | |
| if measure == "occupancy": | |
| f = partial( | |
| self.position_to_side, | |
| split_horizontally=measure_options.get("split_horizontally", True), categoricals=categoricals, | |
| ) | |
| times, time_diff, categoricals_index = data_group.transform_to_categorical_pos(f, categoricals) | |
| elif "freezing" in measure: | |
| freeze_i = categoricals.index(measure_options["freeze_name"]) | |
| non_freeze_i = 1 if freeze_i == 0 else 0 | |
| if measure == "motion_index_freezing": | |
| cm_per_px = data_group.cm_per_px or 1 | |
| index = data_group.motion_index[:-1] * cm_per_px | |
| valid = index >= 0 | |
| times = data_group.times | |
| time_diff = times[1:] - times[:-1] | |
| assert np.all(time_diff >= 0) | |
| index = index[valid] | |
| time_diff = time_diff[valid] | |
| times = times[:-1][valid] | |
| else: | |
| assert measure == "speed_freezing" | |
| times, index, _, time_diff = data_group.calculate_speed() | |
| categoricals_index = np.empty(len(times)) | |
| freezing = index <= measure_options["threshold"] | |
| categoricals_index[freezing] = freeze_i | |
| categoricals_index[np.logical_not(freezing)] = non_freeze_i | |
| else: | |
| raise ValueError(f"Unknown measure {measure}") | |
| return times, time_diff, categoricals_index | |
| def position_to_grid_index( | |
| self, x, y, grid_width: int, grid_height: int, y_at_top: bool = False | |
| ) -> tuple[int, int]: | |
| offset_x, offset_y = self.image_offset | |
| cx, cy = self.box_center | |
| bw, bh = self.box_size | |
| largest_bw, largest_bh = self.largest_box_size | |
| width_scale = largest_bw / bw | |
| height_scale = largest_bh / bh | |
| x = (x - cx) * width_scale + cx + offset_x | |
| y = (y - cy) * height_scale + cy + offset_y | |
| x = int(min(max(round(x), 0), grid_width - 1)) | |
| y = int(min(max(round(y), 0), grid_height - 1)) | |
| if y_at_top: | |
| y = grid_height - 1 - y | |
| return x, y | |
| def read_pos_track( | |
| self, data_root: Path, pre_offset: float = 0, pre_duration: float | None = None, trial_offset: float = 0, | |
| trial_duration: float | None = None, post_offset: float = 0, post_duration: float | None = None, | |
| frame_rate: float = 30. | |
| ) -> None: | |
| self.pos_data = SubjectPos.parse_csv_track( | |
| self.pos_path_filename(data_root), frame_rate=frame_rate, cm_per_px=self.cm_per_px | |
| ) | |
| start = self.pre_start + pre_offset | |
| end = self.pre_end | |
| if pre_duration and pre_duration > 0: | |
| end = min(end, start + pre_duration) | |
| elif pre_duration and pre_duration < 0: | |
| start = max(start, end + pre_duration) | |
| self.pre_pos_data = self.pos_data.extract_range(start, end) | |
| start = self.trial_start + trial_offset | |
| end = self.trial_end | |
| if trial_duration and trial_duration > 0: | |
| end = min(end, start + trial_duration) | |
| elif trial_duration and trial_duration < 0: | |
| start = max(start, end + trial_duration) | |
| self.trial_pos_data = self.pos_data.extract_range(start, end) | |
| start = self.post_start + post_offset | |
| end = self.post_end | |
| if post_duration and post_duration > 0: | |
| end = min(end, start + post_duration) | |
| elif post_duration and post_duration < 0: | |
| start = max(start, end + post_duration) | |
| self.post_pos_data = self.pos_data.extract_range(start, end) | |
| def chop_pos_track( | |
| self, hab_segments: list[tuple[float, float]] = (), trial_segments: list[tuple[float, float]] = (), | |
| post_segments: list[tuple[float, float]] = (), | |
| ) -> None: | |
| data = self.chopped_pos_data = {} | |
| for name in ("Pre", "Trial", "Post"): | |
| match name: | |
| case "Pre": | |
| segments = hab_segments | |
| ts = self.pre_start | |
| te = self.pre_end | |
| case "Trial": | |
| segments = trial_segments | |
| ts = self.trial_start | |
| te = self.trial_end | |
| case "Post": | |
| segments = post_segments | |
| ts = self.post_start | |
| te = self.post_end | |
| case _: | |
| assert False | |
| for start, end in segments: | |
| item = self.pos_data.extract_range(ts + start, min(ts + end, te)) | |
| key = f"{name}_{int(start)}" if len(segments) > 1 else name | |
| data[key] = item | |
| def extract_subject_period( | |
| self, hab_segment: tuple[float, float] = (), trial_segment: tuple[float, float] = (), | |
| post_segment: tuple[float, float] = (), | |
| ) -> Optional["SubjectPos"]: | |
| if hab_segment: | |
| segment = hab_segment | |
| ts = self.pre_start | |
| te = self.pre_end | |
| elif trial_segment: | |
| segment = trial_segment | |
| ts = self.trial_start | |
| te = self.trial_end | |
| elif post_segment: | |
| segment = post_segment | |
| ts = self.post_start | |
| te = self.post_end | |
| else: | |
| return None | |
| start, end = segment | |
| item = self.pos_data.extract_range(ts + start, min(ts + end, te)) | |
| return item | |
| @classmethod | |
| def _get_data_items(cls, obj: Union["Experiment", None] = None) -> list[tuple[Union["SubjectPos", str], str]]: | |
| a, b, c = cls.triplet_name | |
| if obj is None: | |
| res = [ | |
| ("pre_pos_data", a), | |
| ("trial_pos_data", b), | |
| ("post_pos_data", c), | |
| ] | |
| else: | |
| res = [ | |
| (obj.pre_pos_data, a), | |
| (obj.trial_pos_data, b), | |
| (obj.post_pos_data, c), | |
| ] | |
| return res | |
| @classmethod | |
| def _iter_periods_and_groups(cls, experiments, periods, axs, filter_args: list[dict] | None): | |
| n_groups = 1 if filter_args is None else len(filter_args) | |
| it = iter(axs.flatten()) | |
| for i, filter_group in enumerate(filter_args or [None, ]): | |
| if n_groups > 1: | |
| experiments_ = cls.filter(experiments, **filter_group) | |
| else: | |
| experiments_ = experiments | |
| for j, (data_name, t) in enumerate(periods): | |
| ax = next(it) | |
| yield experiments_, i, filter_group, j, data_name, t, ax | |
| def plot_occupancy( | |
| self, | |
| gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True, | |
| scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
| show_point: bool = True, point: tuple[float | str, float | str] | None = None, | |
| ): | |
| if point is None: | |
| point = self.cue_point | |
| point = self._descriptor_to_point(point) | |
| fig, axs = plt.subplots(1, 3, sharey=True, sharex=True) | |
| unit = "cm" if self.cm_per_px else "pixels" | |
| for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
| data.plot_occupancy( | |
| *self.enlarged_image_size, pos_to_index=self.position_to_grid_index, fig=fig, ax=ax, | |
| gaussian_sigma=gaussian_sigma, | |
| intensity_limit=intensity_limit, frame_normalize=frame_normalize, scale_to_one=scale_to_one, | |
| color_bar=not i, | |
| x_label=f"X ({unit})", | |
| y_label=f"Y ({unit})" if not i else "", | |
| title="", | |
| ) | |
| ax.set_title(f"$\\bf{{{title}}}$") | |
| if show_point: | |
| cm_per_px = data.cm_per_px or 1 | |
| x, y = self.position_to_grid_index(*point, *self.enlarged_image_size) | |
| ax.plot([x * cm_per_px], [y * cm_per_px], "*r") | |
| label = self.title_root.format(**self.metadata) | |
| fig.suptitle(f"{label} occupancy density") | |
| save_or_show(save_fig_root, save_fig_prefix) | |
    @classmethod
    def plot_multi_experiment_occupancy(
            cls, experiments: list["Experiment"],
            gaussian_sigma: float = 0, intensity_limit: float = 0, title: str = "Subjects occupancy density",
            frame_normalize: bool = True, experiment_normalize: bool = True, scale_to_one: bool = True,
            save_fig_root: None | Path = None, save_fig_prefix: str = "",
            filter_args: list[dict] | None = None, group_label: str = "",
            show_point: bool = True, point: tuple[float | str, float | str] | None = None,
    ):
        """Plot the pooled occupancy density of many experiments in a (group x period) grid.

        Rows are the groups selected by ``filter_args`` (labelled with ``group_label.format(**group)`` when
        there is more than one group); columns are the pre / trial / post periods. Within a cell, every
        experiment's occupancy is added onto one grid sized by their shared ``enlarged_image_size``. With
        ``experiment_normalize`` the sum is divided by the number of experiments. When ``show_point`` is
        set, ``point`` (default: the first experiment's ``cue_point``) is marked with a red star.

        :raises ValueError: If the experiments of a group have different ``enlarged_image_size``.
        """
        n_groups = 1 if filter_args is None else len(filter_args)
        periods = cls._get_data_items()
        n_periods = len(periods)
        fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
        for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
                experiments, periods, axs, filter_args):
            # an empty group leaves its cell blank
            if not experiments_:
                continue
            unit = "cm" if experiments_[0].cm_per_px else "pixels"
            cm_per_px = experiments_[0].cm_per_px
            # all experiments of a group must share the same scale (checked with assert only)
            assert len({exp.cm_per_px for exp in experiments_}) == 1
            grid_size = {e.enlarged_image_size for e in experiments_}
            if len(grid_size) != 1:
                raise ValueError("Trying to plot occupancy of experiments with different sizes")
            grid_size = list(grid_size)[0]
            occupancy = np.zeros(grid_size)
            for experiment in experiments_:
                # NOTE(review): return value is ignored, so calculate_occupancy presumably accumulates into
                # `occupancy` in place — confirm in SubjectPos
                getattr(experiment, data_name).calculate_occupancy(
                    occupancy, experiment.position_to_grid_index, frame_normalize,
                )
            if experiment_normalize:
                occupancy /= len(experiments_)
            group = ""
            if n_groups > 1:
                group = group_label.format(**filter_group)
            # only the top-left cell gets a color bar, only the bottom row x labels, only the left column y labels
            SubjectPos.plot_occupancy_data(
                occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one,
                color_bar=not i and not j,
                title="",
                x_label=f"X ({unit})" if i == n_groups - 1 else "",
                y_label=f"{group}Y ({unit})" if not j else "",
                cm_per_px=cm_per_px,
            )
            if not i:
                ax.set_title(f"$\\bf{{{t}}}$")
            if show_point:
                # the point is resolved against the first experiment's box geometry
                exp_point = experiments_[0].cue_point if point is None else point
                exp_point = experiments_[0]._descriptor_to_point(exp_point)
                x, y = experiments_[0].position_to_grid_index(*exp_point, *grid_size)
                ax.plot([x * (cm_per_px or 1)], [y * (cm_per_px or 1)], "*r")
        fig.suptitle(title)
        save_or_show(save_fig_root, save_fig_prefix)
| @classmethod | |
| def export_multi_experiment_motion( | |
| cls, experiments: list["Experiment"], filename: Path, | |
| measure: Literal["motion_index", "speed"] = "motion_index", use_chopped_data: bool = False, | |
| ) -> None: | |
| unit = "px" if experiments[0].cm_per_px is None else "cm" | |
| header = experiments[0].get_header_id_columns() + ["Stage", f"Motion mean ({unit}/s)"] | |
| filename.parent.mkdir(parents=True, exist_ok=True) | |
| with open(filename, "w", newline="") as fh: | |
| writer = csv.writer(fh, delimiter=",") | |
| writer.writerow(header) | |
| for experiment in experiments: | |
| cm_per_px = experiment.cm_per_px or 1 | |
| if use_chopped_data: | |
| stages = list(experiment.chopped_pos_data.keys()) | |
| stages_data = [experiment.chopped_pos_data[stage] for stage in stages] | |
| else: | |
| stages = "Pre", "Trial", "Post" | |
| stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data | |
| id_columns = experiment.get_id_columns() | |
| for stage, data in zip(stages, stages_data): | |
| if measure == "motion_index": | |
| motion = data.motion_index * cm_per_px | |
| elif measure == "speed": | |
| _, motion, _, _ = data.calculate_speed() | |
| motion = motion[motion >= 0] | |
| mean_motion = np.mean(motion) if len(motion) else 0 | |
| line = id_columns + [stage, mean_motion] | |
| writer.writerow(map(str, line)) | |
| @classmethod | |
| def export_bouts( | |
| cls, experiments: list["Experiment"], filename: Path, measure: BOUT_MEASURE_TYPE, | |
| threshold: float, frame_rate: float, refractory_period: float = 0, | |
| downsample_factor: int = 1, use_chopped_data: bool = False, min_bout_frames: int = 0, | |
| unit: Literal["percent", "times"] = "times", | |
| ) -> None: | |
| header = experiments[0].get_header_id_columns() + [ | |
| "Stage", | |
| "Number of bouts" + ("" if unit == "times" else " / s"), | |
| "Mean duration (s)", | |
| "Total duration " + ("(s)" if unit == "times" else "(percent)"), | |
| ] | |
| filename.parent.mkdir(parents=True, exist_ok=True) | |
| with open(filename, "w", newline="") as fh: | |
| writer = csv.writer(fh, delimiter=",") | |
| writer.writerow(header) | |
| for experiment in experiments: | |
| if use_chopped_data: | |
| stages = list(experiment.chopped_pos_data.keys()) | |
| stages_data = [experiment.chopped_pos_data[stage] for stage in stages] | |
| else: | |
| stages = "Pre", "Trial", "Post" | |
| stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data | |
| id_columns = experiment.get_id_columns() | |
| for stage, data in zip(stages, stages_data): | |
| bouts = data.calculate_measure_bout( | |
| measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames, | |
| ) | |
| dur = sum(b[1] for b in bouts) | |
| n = len(bouts) | |
| dur_mean = (dur / len(bouts)) if bouts else 0 | |
| if unit == "percent": | |
| dur = dur / data.duration * 100 | |
| n /= data.duration | |
| line = id_columns + [stage, n, dur_mean, dur] | |
| writer.writerow(map(str, line)) | |
| def plot_motion_index( | |
| self, y_limit: float | None = None, save_fig_root: None | Path = None, save_fig_prefix: str = "" | |
| ): | |
| fig, axs = plt.subplots(1, 3, sharey=True, sharex=False) | |
| for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
| data.plot_motion_index( | |
| fig, ax, **{"y_label": ""} if i else {}, | |
| ) | |
| if y_limit is not None: | |
| ax.set_ylim(0, y_limit) | |
| ax.set_title(f"$\\bf{{{title}}}$") | |
| label = self.title_root.format(**self.metadata) | |
| fig.suptitle(f"{label} motion index") | |
| save_or_show(save_fig_root, save_fig_prefix) | |
| @classmethod | |
| def plot_multi_experiment_motion( | |
| cls, experiments: list["Experiment"], y_limit: float | None = None, title: str = "Subjects motion index", | |
| save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
| filter_args: list[dict] | None = None, group_label: str = "", | |
| measure: Literal["motion_index", "speed"] = "motion_index", | |
| ): | |
| n_groups = 1 if filter_args is None else len(filter_args) | |
| periods = cls._get_data_items() | |
| n_periods = len(periods) | |
| fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False) | |
| for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups( | |
| experiments, periods, axs, filter_args): | |
| for experiment in experiments_: | |
| group = "" | |
| if n_groups > 1: | |
| group = group_label.format(**filter_group) | |
| kwargs = {"y_label": ""} | |
| if i != n_groups - 1: | |
| kwargs["x_label"] = "" | |
| if not j: | |
| label = "Motion index" if measure == "motion_index" else "Speed" | |
| if experiment.cm_per_px: | |
| label += " (cm/s)" | |
| else: | |
| label += " (px/s)" | |
| kwargs["y_label"] = f"{group}{label}" | |
| data = getattr(experiment, data_name) | |
| getattr(data, f"plot_{measure}")(fig, ax, **kwargs) | |
| if y_limit is not None: | |
| ax.set_ylim(0, y_limit) | |
| if not i: | |
| ax.set_title(f"$\\bf{{{t}}}$") | |
| fig.suptitle(title) | |
| save_or_show(save_fig_root, save_fig_prefix) | |
| def plot_motion_index_histogram( | |
| self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
| hist_range: tuple[float, float] = (0, 3), log_x: bool = False, | |
| ): | |
| fig, axs = plt.subplots(1, 3, sharey=True, sharex=True) | |
| cm_per_px = self.cm_per_px or 1 | |
| unit = "px" if self.cm_per_px is None else "cm" | |
| for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
| data = data.motion_index * cm_per_px | |
| if log_x: | |
| total = len(data) | |
| data = data[data > 0] | |
| if not len(data): | |
| continue | |
| non_zero = len(data) | |
| percent_zero = (total - non_zero) / total * 100 | |
| ax.hist(data, bins=get_log_bins(data, n_bins), density=True) | |
| ax.set_xscale("log") | |
| ax.set_title(f"$\\bf{{{title}}}$ zeros = {percent_zero:0.2f}%") | |
| else: | |
| data = data[data >= 0] | |
| if not len(data): | |
| continue | |
| ax.hist(data, bins=n_bins, density=True) | |
| ax.set_title(f"$\\bf{{{title}}}$") | |
| ax.set_xlim(*hist_range) | |
| ax.set_xlabel(f"Motion index ({unit}/s)") | |
| if not i: | |
| ax.set_ylabel("Density") | |
| label = self.title_root.format(**self.metadata) | |
| fig.suptitle(f"{label} motion index density") | |
| save_or_show(save_fig_root, save_fig_prefix) | |
    @classmethod
    def plot_multi_experiment_motion_index_histogram(
            cls, experiments: list["Experiment"], n_bins=100, title: str = "Subjects motion index",
            save_fig_root: None | Path = None, save_fig_prefix: str = "", hist_range: tuple[float, float] = (0, 3),
            filter_args: list[dict] | None = None, group_label: str = "", log_x: bool = False,
    ):
        """Plot the pooled motion-index density of many experiments in a (group x period) grid.

        Each experiment's motion index is scaled by its ``cm_per_px`` (1 when unset), its negative samples
        are dropped, and all experiments of a cell are concatenated into one histogram. With ``log_x``,
        zeros are also dropped, log-spaced bins are used, and the title reports the percentage of zeros.
        Does nothing when ``experiments`` is empty. The axis unit comes from the first experiment.
        """
        if not experiments:
            return
        n_groups = 1 if filter_args is None else len(filter_args)
        periods = cls._get_data_items()
        n_periods = len(periods)
        fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
        unit = "px" if experiments[0].cm_per_px is None else "cm"
        for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
                experiments, periods, axs, filter_args):
            # an empty group leaves its cell blank
            if not experiments_:
                continue
            items = []
            for experiment in experiments_:
                cm_per_px = experiment.cm_per_px or 1
                data = getattr(experiment, data_name).motion_index * cm_per_px
                items.append(data[data >= 0])
            if log_x:
                data = np.concatenate(items)
                total = len(data)
                # log scale cannot show zeros; report their share in the title instead
                data = data[data > 0]
                if not len(data):
                    continue
                non_zero = len(data)
                percent_zero = (total - non_zero) / total * 100
                ax.hist(data, bins=get_log_bins(data, n_bins), density=True)
                ax.set_xscale("log")
                # only the top row carries the period name
                if i:
                    ax.set_title(f"zeros = {percent_zero:0.2f}%")
                else:
                    ax.set_title(f"$\\bf{{{t}}}$ zeros = {percent_zero:0.2f}%")
            else:
                if items:
                    data = np.concatenate(items)
                    if not len(data):
                        continue
                    ax.hist(data, bins=n_bins, density=True)
                if not i:
                    ax.set_title(f"$\\bf{{{t}}}$")
            ax.set_xlim(*hist_range)
            group = ""
            if n_groups > 1:
                group = group_label.format(**filter_group)
            # x labels on the bottom row only, y labels on the left column only
            if i == n_groups - 1:
                ax.set_xlabel(f"Motion index ({unit}/s)")
            if not j:
                ax.set_ylabel(f"{group}Density")
        fig.suptitle(title)
        save_or_show(save_fig_root, save_fig_prefix)
| def plot_speed( | |
| self, y_limit: float | None = None, save_fig_root: None | Path = None, | |
| save_fig_prefix: str = "" | |
| ): | |
| fig, axs = plt.subplots(1, 3, sharey=True, sharex=False) | |
| for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
| data.plot_speed( | |
| fig, ax, **{"y_label": ""} if i else {}, | |
| ) | |
| if y_limit is not None: | |
| ax.set_ylim(0, y_limit) | |
| ax.set_title(f"$\\bf{{{title}}}$") | |
| label = self.title_root.format(**self.metadata) | |
| fig.suptitle(f"{label} speed") | |
| save_or_show(save_fig_root, save_fig_prefix) | |
| def plot_speed_histogram( | |
| self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
| hist_range: tuple[float, float] = (0, 200), log_x: bool = False, | |
| ): | |
| fig, axs = plt.subplots(1, 3, sharey=True, sharex=True) | |
| for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
| _, data, _, _ = data.calculate_speed() | |
| if log_x: | |
| total = len(data) | |
| data = data[data > 0] | |
| if not len(data): | |
| continue | |
| non_zero = len(data) | |
| percent_zero = (total - non_zero) / total * 100 | |
| ax.hist(data, bins=get_log_bins(data, n_bins), density=True) | |
| ax.set_xscale("log") | |
| ax.set_title(f"$\\bf{{{title}}}$ zeros = {percent_zero:0.2f}%") | |
| else: | |
| data = data[data >= 0] | |
| if not len(data): | |
| continue | |
| ax.hist(data, bins=n_bins, density=True) | |
| ax.set_title(f"$\\bf{{{title}}}$") | |
| ax.set_xlim(*hist_range) | |
| ax.set_xlabel("Speed (cm/s)" if self.cm_per_px else "Speed (px/s)") | |
| if not i: | |
| ax.set_ylabel("Density") | |
| label = self.title_root.format(**self.metadata) | |
| fig.suptitle(f"{label} Speed density") | |
| save_or_show(save_fig_root, save_fig_prefix) | |
    @classmethod
    def plot_multi_experiment_speed_histogram(
            cls, experiments: list["Experiment"], n_bins=100,
            title: str = "Subjects Speed", hist_range: tuple[float, float] = (0, 200),
            save_fig_root: None | Path = None, save_fig_prefix: str = "",
            filter_args: list[dict] | None = None, group_label: str = "", log_x: bool = False,
    ):
        """Plot the pooled speed density of many experiments in a (group x period) grid.

        Each experiment's speed (from ``calculate_speed``) has its negative samples dropped, and all
        experiments of a cell are concatenated into one histogram. With ``log_x``, zeros are also dropped,
        log-spaced bins are used, and the title reports the percentage of zeros. All experiments of a group
        must share ``cm_per_px``, which is checked with ``assert``.
        """
        n_groups = 1 if filter_args is None else len(filter_args)
        periods = cls._get_data_items()
        n_periods = len(periods)
        fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
        for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
                experiments, periods, axs, filter_args):
            # an empty group leaves its cell blank
            if not experiments_:
                continue
            speed_label = "Speed (cm/s)" if experiments_ and experiments_[0].cm_per_px else "Speed (px/s)"
            assert len({exp.cm_per_px for exp in experiments_}) == 1
            items = []
            for experiment in experiments_:
                _, data, _, _ = getattr(experiment, data_name).calculate_speed()
                items.append(data[data >= 0])
            if log_x:
                data = np.concatenate(items)
                total = len(data)
                # log scale cannot show zeros; report their share in the title instead
                data = data[data > 0]
                if not len(data):
                    continue
                non_zero = len(data)
                percent_zero = (total - non_zero) / total * 100
                ax.hist(data, bins=get_log_bins(data, n_bins), density=True)
                ax.set_xscale("log")
                # only the top row carries the period name
                if i:
                    ax.set_title(f"zeros = {percent_zero:0.2f}%")
                else:
                    ax.set_title(f"$\\bf{{{t}}}$ zeros = {percent_zero:0.2f}%")
            else:
                if items:
                    data = np.concatenate(items)
                    if len(data):
                        ax.hist(data, bins=n_bins, density=True)
                if not i:
                    ax.set_title(f"$\\bf{{{t}}}$")
            ax.set_xlim(*hist_range)
            group = ""
            if n_groups > 1:
                group = group_label.format(**filter_group)
            # x labels on the bottom row only, y labels on the left column only
            if i == n_groups - 1:
                ax.set_xlabel(speed_label)
            if not j:
                ax.set_ylabel(f"{group}Density")
        fig.suptitle(f"{title}")
        save_or_show(save_fig_root, save_fig_prefix)
def plot_categorical_values(
        self, measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), save_fig_root: None | Path = None,
        save_fig_prefix: str = "",
):
    """Plot this subject's categorical state over time, one panel per period.

    :param measure: Measure used to classify each sample (see ``convert_to_categoricals``).
    :param measure_options: Extra options forwarded to ``convert_to_categoricals``.
    :param categoricals: Category names, in index order.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    panels = zip(self._get_data_items(self), axs.flatten())
    for (period_data, period_title), ax in panels:
        times, _, cat_index = self.convert_to_categoricals(period_data, measure, categoricals, measure_options)
        period_data.plot_categorical_values(times, cat_index, categoricals, fig, ax)
        ax.set_title(period_title)
    subject = self.title_root.format(**self.metadata)
    fig.suptitle(f"Subject {subject}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_side_of_box(
        cls, experiments: list["Experiment"], measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Overlay every experiment's categorical time series, one panel per period.

    :param experiments: Experiments to overlay.
    :param measure: Measure used to classify each sample (see ``convert_to_categoricals``).
    :param measure_options: Extra options forwarded to ``convert_to_categoricals``.
    :param categoricals: Category names, in index order.
    :param title: Figure title.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    for (period_name, period_title), ax in zip(cls._get_data_items(), axs.flatten()):
        for exp in experiments:
            times, _, cat_index = exp.convert_to_categoricals(period_name, measure, categoricals, measure_options)
            period_data = getattr(exp, period_name)
            period_data.plot_categorical_values(times, cat_index, categoricals, fig, ax)
        ax.set_title(period_title)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
| def _descriptor_to_point(self, point: tuple[str, str] | tuple[float, float]): | |
| cx, cy = self.box_center | |
| bw, bh = self.box_size | |
| match point[0]: | |
| case "left": | |
| x = cx - bw / 2 | |
| case "right": | |
| x = cx + bw / 2 | |
| case "center": | |
| x = cx | |
| case int() | float(): | |
| x = point[0] | |
| case _: | |
| raise ValueError(f"Can't recognize {point[0]}") | |
| match point[1]: | |
| case "bottom": | |
| y = cy + bh / 2 | |
| case "top": | |
| y = cy - bh / 2 | |
| case "center": | |
| y = cy | |
| case int() | float(): | |
| y = point[1] | |
| case _: | |
| raise ValueError(f"Can't recognize {point[1]}") | |
| return x, y | |
def plot_distance_from_point(
        self, point: tuple[str, str] | None = None, save_fig_root: None | Path = None,
        save_fig_prefix: str = "", post_title: str = "distance from teaball corner",
):
    """Plot the subject's distance from a point, one panel per period.

    :param point: Point descriptor (see ``_descriptor_to_point``); None uses ``self.cue_point``.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    :param post_title: Text appended to the subject label in the figure title.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    target = self._descriptor_to_point(self.cue_point if point is None else point)
    panels = zip(self._get_data_items(self), axs.flatten())
    for idx, ((period_data, period_title), ax) in enumerate(panels):
        # Only the first panel keeps its y label; the axes share y.
        extra = {"y_label": ""} if idx else {}
        period_data.plot_distance_from_point(target, fig, ax, **extra)
        ax.set_title(f"$\\bf{{{period_title}}}$")
    subject = self.title_root.format(**self.metadata)
    fig.suptitle(f"{subject} {post_title}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_distance_from_point(
        cls, experiments: list["Experiment"], point: tuple[str, str] | None = None,
        title: str = "Subjects distance from teaball corner",
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
):
    """Overlay each experiment's distance from a point on a grid of groups (rows) by periods (columns).

    :param experiments: Experiments to plot.
    :param point: Point descriptor; None uses each experiment's own ``cue_point``.
    :param title: Figure title.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    :param filter_args: One metadata filter per row; None means a single row.
    :param group_label: Format string filled from each filter dict to prefix the row's y label.
    """
    periods = cls._get_data_items()
    n_rows = len(filter_args) if filter_args is not None else 1
    fig, axs = plt.subplots(n_rows, len(periods), sharey=True, sharex=False)
    cells = cls._iter_periods_and_groups(experiments, periods, axs, filter_args)
    for group_exps, row, filter_group, col, period_name, period_title, ax in cells:
        if not group_exps:
            continue
        unit = "cm" if group_exps[0].cm_per_px else "px"
        assert len({exp.cm_per_px for exp in group_exps}) == 1
        # The axis labels depend only on the cell, so build them once for all its experiments.
        row_prefix = group_label.format(**filter_group) if n_rows > 1 else ""
        label_kwargs = {"y_label": "" if col else f"{row_prefix}Distance ({unit})"}
        if row != n_rows - 1:
            label_kwargs["x_label"] = ""
        for exp in group_exps:
            target = exp._descriptor_to_point(exp.cue_point if point is None else point)
            getattr(exp, period_name).plot_distance_from_point(target, fig, ax, **label_kwargs)
        if not row:
            ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_bouts(
        self, measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float, refractory_period: float = 0,
        downsample_factor: int = 1, min_bout_frames: int = 0, save_fig_root: None | Path = None,
        save_fig_prefix: str = "", plot_type: Literal["scatter", "count", "mean", "duration"] = "scatter",
        unit: Literal["percent", "times"] = "times",
):
    """Plot this subject's bouts of ``measure``, one panel per period.

    Bout detection parameters are forwarded unchanged to each period's ``plot_bout``. The panel
    title reports the number of bouts that ``plot_bout`` returned.

    :param plot_type: ``"scatter"``, ``"count"``, ``"mean"`` or ``"duration"``; selects the axis labels.
    :param unit: ``"percent"`` normalizes by period time (``norm_time=True``); ``"times"`` keeps raw values.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    x_label = "Bouts" if plot_type != "scatter" else "Time (min)"
    if plot_type == "scatter":
        y_label = "Duration (s)"
    elif plot_type == "count":
        y_label = "Count / s" if unit != "times" else "Count"
    elif plot_type == "mean":
        y_label = "Mean duration (s)"
    elif plot_type == "duration":
        y_label = "Total duration " + ("(percent)" if unit != "times" else "(s)")
    else:
        y_label = ""
    panels = zip(self._get_data_items(self), axs.flatten())
    for idx, ((period_data, period_title), ax) in enumerate(panels):
        bouts = period_data.plot_bout(
            measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames, plot_type,
            fig, ax, x_label=x_label, norm_time=unit == "percent", y_label="" if idx else y_label,
        )
        ax.set_title(f"$\\bf{{{period_title}}}$ {len(bouts)} bouts")
    subject = self.title_root.format(**self.metadata)
    pretty_measure = measure.capitalize().replace("_", " ")
    fig.suptitle(f"{subject} {pretty_measure} bouts")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_bouts(
        cls, experiments: list["Experiment"], measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float,
        refractory_period: float = 0, downsample_factor: int = 1, min_bout_frames: int = 0,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
        plot_type: Literal["scatter", "count", "mean", "duration"] = "scatter", unit: Literal["percent", "times"] = "times",
):
    """Plot bouts of ``measure`` on a grid of filter groups (rows) by periods (columns).

    ``"scatter"`` draws every bout of every experiment via ``plot_bout``. The other plot types draw
    one bar per cell (mean over experiments) plus one dot per experiment:

    * ``"count"``: number of bouts, or bouts per second of period duration when ``unit="percent"``.
    * ``"mean"``: mean bout duration (0 for an experiment with no bouts).
    * ``"duration"``: total bout duration, or a percentage of the period duration when ``unit="percent"``.

    :param filter_args: One metadata filter per row; None means a single row.
    :param group_label: Format string filled from each filter dict to prefix the row's y label.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    :raises ValueError: if ``plot_type`` is not one of the supported values.
    """
    # Validate before creating the figure so a bad type fails clearly rather than deep in the loop
    # (an ``assert`` there would vanish under ``python -O`` and leave ``values`` unbound).
    if plot_type not in ("scatter", "count", "mean", "duration"):
        raise ValueError(f"Unknown plot_type {plot_type!r}")

    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
    x_label = "Time (min)" if plot_type == "scatter" else "Bouts"
    y_label = {
        "scatter": "Duration (s)",
        "count": "Count" if unit == "times" else "Count / s",
        "mean": "Mean duration (s)",
        "duration": "Total duration " + ("(s)" if unit == "times" else "(percent)"),
    }[plot_type]
    bout_args = (measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames)

    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        if not experiments_:
            continue
        assert len({exp.cm_per_px for exp in experiments_}) == 1
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        kwargs = {
            "y_label": "" if j else f"{group}{y_label}",
            "x_label": x_label if i == n_groups - 1 else "",
        }
        if plot_type == "scatter":
            for experiment in experiments_:
                getattr(experiment, data_name).plot_bout(*bout_args, plot_type, fig, ax, **kwargs)
        else:
            period_data = [getattr(experiment, data_name) for experiment in experiments_]
            all_bouts = [data.calculate_measure_bout(*bout_args) for data in period_data]
            durations = np.array([data.duration for data in period_data], dtype=np.float64)
            if plot_type == "count":
                values = np.array([len(bouts) for bouts in all_bouts], dtype=np.float64)
                if unit == "percent":
                    values /= durations  # bouts per second
            elif plot_type == "mean":
                values = np.array(
                    [sum(b[1] for b in bouts) / len(bouts) if bouts else 0 for bouts in all_bouts],
                    dtype=np.float64,
                )
            else:  # "duration"
                values = np.array([sum(b[1] for b in bouts) for bouts in all_bouts], dtype=np.float64)
                if unit == "percent":
                    values = values / durations * 100
            ax.bar([0], [np.mean(values)])
            ax.plot([0] * len(values), values, "k.")
            ax.set_xticks([0], [""])
            if kwargs["x_label"]:
                ax.set_xlabel(kwargs["x_label"])
            if kwargs["y_label"]:
                ax.set_ylabel(kwargs["y_label"])
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
    name = measure.capitalize().replace("_", " ")
    fig.suptitle(f"{name} bouts")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def _get_categorical_percent(
        cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray],
        sorted_categoricals: tuple[str, ...], unit: Literal["percent", "times"] = "times",
):
    """Summarize how much each category was occupied, per experiment and averaged over experiments.

    With ``unit="times"`` each value is the sum of ``times_diff`` over the samples in that category.
    With ``unit="percent"`` each value is the percentage of samples (frames) in that category.
    ``times_diff`` is not used in that mode.

    :param categoricals_index: One array per experiment of per-sample category indices.
    :param times_diff: One array per experiment, aligned with ``categoricals_index``.
    :param sorted_categoricals: Category names; index ``k`` refers to ``sorted_categoricals[k]``.
    :param unit: ``"times"`` or ``"percent"``.
    :return: ``(mean_values, per_experiment)``. ``mean_values`` has shape ``(n_categories,)``, and
        ``per_experiment`` is a list with one ``(n_categories,)`` array per experiment. Shapes stay
        1-D even with a single category, so callers can always unpack or slice them.
    :raises ValueError: on mismatched lengths or an unknown ``unit``.
    """
    n_exp = len(categoricals_index)
    n_cat = len(sorted_categoricals)
    if unit == "times":
        # zip() would silently stop early and leave rows unfilled if the lists disagree.
        if len(times_diff) != n_exp:
            raise ValueError("Provided time diff and categorical index lists are not the same length")
        values = np.zeros((n_exp, n_cat))
        for row, (exp_index, exp_dt) in enumerate(zip(categoricals_index, times_diff)):
            if len(exp_index) != len(exp_dt):
                raise ValueError("Provided time diff and categorical index are not the same length")
            for cat in range(n_cat):
                values[row, cat] = np.sum(exp_dt[exp_index == cat])
    elif unit == "percent":
        # reshape keeps a 2-D array even when there are no experiments.
        counts = np.array(
            [[np.sum(exp_index == cat) for cat in range(n_cat)] for exp_index in categoricals_index],
            dtype=np.float64,
        ).reshape(n_exp, n_cat)
        values = counts / np.sum(counts, axis=1, keepdims=True) * 100
    else:
        raise ValueError(f"Unknown unit {unit}")
    mean_values = np.mean(values, axis=0)
    return mean_values, list(values)
@classmethod
def _plot_percents(
        cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray],
        sorted_categoricals: tuple[str, ...],
        fig: plt.Figure, ax: plt.Axes,
        x_label: str = "Teaball side", y_label: str = "% time spent",
        unit: Literal["percent", "times"] = "times",
):
    """Draw one bar per category (mean over experiments) into ``ax``.

    When more than one experiment is given, each experiment's values are overlaid as black dots.
    Empty ``x_label`` / ``y_label`` leave the corresponding axis label untouched.
    """
    mean_values, per_exp = cls._get_categorical_percent(categoricals_index, times_diff, sorted_categoricals, unit)
    positions = np.arange(len(sorted_categoricals))
    ax.bar(positions, mean_values, tick_label=sorted_categoricals)
    if len(categoricals_index) > 1:
        for exp_values in per_exp:
            ax.plot(positions, exp_values.squeeze(), "k.")
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
def plot_categorical_percent(
        self, measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        unit: Literal["percent", "times"] = "times", x_label="Odor side", action_label: str = "spent in side",
):
    """Bar-plot this subject's time per category, one panel per period.

    :param unit: ``"times"`` for total seconds, ``"percent"`` for the share of samples.
    :param x_label: X axis label for every panel.
    :param action_label: Text ending the figure title.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    y_label = "total time spent (s)" if unit != "percent" else "% time spent"
    panels = zip(self._get_data_items(self), axs.flatten())
    for idx, ((period_data, period_title), ax) in enumerate(panels):
        _, dt, cat_index = self.convert_to_categoricals(period_data, measure, categoricals, measure_options)
        panel_y_label = y_label if not idx else ""
        self._plot_percents([cat_index], [dt], categoricals, fig, ax, x_label, panel_y_label, unit)
        ax.set_title(f"$\\bf{{{period_title}}}$")
    subject = self.title_root.format(**self.metadata)
    amount = "%" if unit != "times" else "total"
    fig.suptitle(f"{subject} {amount} time {action_label}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_categorical_percent(
        cls, experiments: list["Experiment"], measure: CAT_MEASURE_TYPE = "occupancy",
        measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
        unit: Literal["percent", "times"] = "times", x_label: str = "Odor side",
):
    """Bar-plot time per category on a grid of filter groups (rows) by periods (columns).

    Each bar is the mean over the cell's experiments. Individual experiments are drawn as dots
    when there are several.

    :param filter_args: One metadata filter per row; None means a single row.
    :param group_label: Format string filled from each filter dict to prefix the row's y label.
    :param unit: ``"times"`` for total seconds, ``"percent"`` for the share of samples.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    """
    periods = cls._get_data_items()
    n_rows = len(filter_args) if filter_args is not None else 1
    fig, axs = plt.subplots(n_rows, len(periods), sharey=True, sharex=False)
    value_label = "% time spent" if unit == "percent" else "total time spent (s)"
    cells = cls._iter_periods_and_groups(experiments, periods, axs, filter_args)
    for group_exps, row, filter_group, col, period_name, period_title, ax in cells:
        if not group_exps:
            continue
        dts, cat_indices = [], []
        for exp in group_exps:
            _, dt, cat_index = exp.convert_to_categoricals(period_name, measure, categoricals, measure_options)
            cat_indices.append(cat_index)
            dts.append(dt)
        row_prefix = group_label.format(**filter_group) if n_rows > 1 else ""
        cls._plot_percents(
            cat_indices, dts, categoricals, fig, ax,
            y_label="" if col else f"{row_prefix}{value_label}",
            x_label=x_label if row == n_rows - 1 else "",
            unit=unit,
        )
        if not row:
            ax.set_title(f"$\\bf{{{period_title}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def export_multi_experiment_categorical_percent(
        cls, experiments: list["Experiment"], filename: Path,
        measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
        unit: Literal["percent", "times"] = "times",
        use_chopped_data: bool = False,
) -> None:
    """Write one CSV row per (experiment, stage) with the time or share spent in each category.

    The columns are the experiment id columns, ``Stage``, and one ``"<category> sec"`` (or
    ``"<category> %"``) column per category. Stages are either ``Pre``/``Trial``/``Post`` or,
    with ``use_chopped_data``, the keys of ``chopped_pos_data``. Parent directories are created.

    :raises ValueError: if ``experiments`` is empty (the header comes from the first experiment).
    """
    if not experiments:
        raise ValueError("No experiments to export")
    sign = "sec" if unit == "times" else "%"
    header = experiments[0].get_header_id_columns() + ["Stage"] + [
        f"{label} {sign}" for label in categoricals
    ]
    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)
        for experiment in experiments:
            if use_chopped_data:
                stages_and_data = list(experiment.chopped_pos_data.items())
            else:
                stages_and_data = [
                    ("Pre", experiment.pre_pos_data),
                    ("Trial", experiment.trial_pos_data),
                    ("Post", experiment.post_pos_data),
                ]
            id_columns = experiment.get_id_columns()
            for stage, data in stages_and_data:
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    data, measure, categoricals, measure_options
                )
                mean_prop, _ = cls._get_categorical_percent([categoricals_index], [times_diff], categoricals, unit)
                # atleast_1d: with a single category the summary may be a 0-d array, which can't be unpacked.
                line = [*id_columns, stage, *np.atleast_1d(mean_prop)]
                writer.writerow(map(str, line))
@classmethod
def plot_multi_experiment_merged_by_period_categorical_percent(
        cls, experiments: list["Experiment"], filter_args: list[dict],
        measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "", group_label: str = "",
        unit: Literal["percent", "times"] = "times", x_label: str = "Teaball side",
):
    """Plot one panel per period with side-by-side bars per filter group for every category.

    Within a panel, group ``i`` is offset by ``i * bar_width`` from each category position. Bars
    show the mean over the group's experiments, with each experiment overlaid as a black dot. The
    figure legend lists the groups, labelled with ``group_label`` formatted from each filter dict.

    :param filter_args: One metadata filter per group.
    :param unit: ``"times"`` for total seconds, ``"percent"`` for the share of samples.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    """
    n_groups = len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    # squeeze=False keeps an array of axes even when there is only one period.
    fig, axs = plt.subplots(1, n_periods, sharey=True, sharex=True, squeeze=False)
    axes = list(axs.flatten())
    bar_width = 1 / (n_groups + 1)
    n_categoricals = len(categoricals)
    y_label = "% time spent" if unit == "percent" else "total time spent (s)"

    for i, filter_group in enumerate(filter_args):
        group_exps = cls.filter(experiments, **filter_group)
        if not group_exps:
            continue
        offsets = i * bar_width + np.arange(n_categoricals)
        bar_label = group_label.format(**filter_group)
        for ax, (data_name, _) in zip(axes, periods):
            categoricals_index_all = []
            times_all = []
            for experiment in group_exps:
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    data_name, measure, categoricals, measure_options
                )
                categoricals_index_all.append(categoricals_index)
                times_all.append(times_diff)
            mean_prop, percents = cls._get_categorical_percent(
                categoricals_index_all, times_all, categoricals, unit
            )
            ax.bar(offsets, mean_prop, bar_width, label=bar_label)
            for prop in percents:
                ax.plot(offsets, prop, "k.")

    # Decorate every panel independently of which groups had data, so an empty first group
    # no longer leaves the panels untitled or unlabelled.
    tick_positions = np.arange(n_categoricals) + (1 - bar_width) / 2 - bar_width / 2
    for j, (ax, (_, t)) in enumerate(zip(axes, periods)):
        ax.set_xticks(tick_positions, categoricals)
        ax.set_xlim(-bar_width, n_categoricals + bar_width)
        ax.set_xlabel(x_label)
        ax.set_title(f"$\\bf{{{t}}}$")
        if not j:
            ax.set_ylabel(y_label)
    handles, labels = axes[-1].get_legend_handles_labels()
    fig.legend(handles, labels, ncols=n_groups, bbox_to_anchor=(0, 0), loc=2)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_merged_by_group_categorical_percent(
        cls, experiments: list["Experiment"], filter_args: list[dict],
        measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Teaball-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "", show_titles: bool = True,
        show_legend: bool = True, only_categorical: str | None = None, show_xlabel: bool = True,
        unit: Literal["percent", "times"] = "times", group_label: str = "", x_label: str = "Teaball side"
):
    """Plot one row per filter group with side-by-side bars per period for every category.

    Within a row, period ``j`` is offset by ``j * bar_width`` from each category position. Bars show
    the mean over the group's experiments, with each experiment overlaid as a black dot. With
    ``only_categorical`` only that category is drawn.

    :param filter_args: One metadata filter per row.
    :param show_titles: Title each row with ``group_label`` formatted from its filter dict.
    :param show_legend: Add a figure legend listing the periods.
    :param only_categorical: Restrict the plot to this category name.
    :param show_xlabel: Put ``x_label`` on the bottom row.
    :param unit: ``"times"`` for total seconds, ``"percent"`` for the share of samples.
    :param save_fig_root: Directory to save into; None shows the figure instead.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    :raises ValueError: if ``only_categorical`` is not in ``categoricals``.
    """
    n_groups = len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    # squeeze=False: with a single group plt.subplots would otherwise return a bare Axes,
    # which has no .flatten().
    fig, axs = plt.subplots(n_groups, 1, sharey=True, sharex=True, squeeze=False)
    axes = list(axs.flatten())
    bar_width = 1 / (n_periods + 1)
    n_categoricals = len(categoricals)
    if only_categorical:
        k = categoricals.index(only_categorical)
        categoricals_s = slice(k, k + 1)
        categoricals_used = [only_categorical]
    else:
        categoricals_s = slice(0, n_categoricals)
        categoricals_used = categoricals
    n_categoricals_used = len(categoricals_used)
    y_label = "% time spent" if unit == "percent" else "total time spent (s)"

    for i, (filter_group, ax) in enumerate(zip(filter_args, axes)):
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        group_exps = cls.filter(experiments, **filter_group)
        if group_exps:
            for j, (data_name, t) in enumerate(periods):
                categoricals_index_all = []
                times_all = []
                for experiment in group_exps:
                    _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                        data_name, measure, categoricals, measure_options
                    )
                    categoricals_index_all.append(categoricals_index)
                    times_all.append(times_diff)
                mean_prop, percents = cls._get_categorical_percent(
                    categoricals_index_all, times_all, categoricals, unit
                )
                xs = j * bar_width + np.arange(n_categoricals_used)
                # atleast_1d: a single-category summary may be 0-d and cannot be sliced.
                ax.bar(xs, np.atleast_1d(mean_prop)[categoricals_s], bar_width, label=t)
                for prop in percents:
                    ax.plot(xs, np.atleast_1d(prop)[categoricals_s], "k.")
        # Labels and titles are set even for groups without data.
        if i == n_groups - 1 and show_xlabel:
            ax.set_xlabel(x_label)
        ax.set_ylabel(f"{group}{y_label}")
        if show_titles:
            ax.set_title(group_label.format(**filter_group))

    for ax in axes:
        ax.set_xticks(np.arange(n_categoricals_used) + (1 - bar_width) / 2 - bar_width / 2, categoricals_used)
        ax.set_xlim(-bar_width, n_categoricals_used - bar_width)
    if show_legend:
        # Every populated row has the same period labels, so use the last row that has any.
        handles, labels = [], []
        for ax in reversed(axes):
            handles, labels = ax.get_legend_handles_labels()
            if handles:
                break
        fig.legend(handles, labels, ncols=1, bbox_to_anchor=(0, 0), loc=2)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix, width_inch=3)
@classmethod
def export_multi_experiment_frames(cls, experiments: list["Experiment"], filename: Path) -> None:
    """Write one CSV row per experiment describing box geometry, scale, frame counts and durations.

    Frame counts and durations are reported for the whole track and for the pre, trial and post
    periods, in that order. A duration is ``last time - first time``, or 0 for an empty period.
    Parent directories of ``filename`` are created.
    """
    header = experiments[0].get_header_id_columns() + [
        "box_center_x_px", "box_center_y_px", "box_width_px", "box_height_px", "image_width_px", "image_height_px",
        "cm_per_px", "total_frames", "pre_frames", "trial_frames", "post_frames", "total_duration", "pre_duration",
        "trial_duration", "post_duration",
    ]
    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)
        for exp in experiments:
            tracks = (exp.pos_data, exp.pre_pos_data, exp.trial_pos_data, exp.post_pos_data)
            frame_counts = [len(track.times) for track in tracks]
            durations = [
                (track.times[-1] - track.times[0]) if len(track.times) else 0
                for track in tracks
            ]
            row = exp.get_id_columns() + [
                *exp.box_center, *exp.box_size, *exp.image_size, exp.cm_per_px, *frame_counts, *durations,
            ]
            writer.writerow([str(value) for value in row])
@classmethod
def filter(cls, experiments: list["Experiment"], **metadata):
    """Return the experiments whose metadata equals every keyword value that is not None.

    Keywords with a None value are ignored. When no keyword constrains anything, the input
    ``experiments`` object itself is returned unchanged.
    """
    criteria = [(key, value) for key, value in metadata.items() if value is not None]
    if not criteria:
        return experiments
    return [
        exp for exp in experiments
        if all(exp.metadata[key] == value for key, value in criteria)
    ]
@classmethod
def count_motion_range(
        cls, experiments: list["Experiment"], measure: Literal["speed", "motion_index"] = "motion_index",
        downsample_factor: int = 1, hab_segment: tuple[float, float] = (), trial_segment: tuple[float, float] = (),
        post_segment: tuple[float, float] = (),
) -> dict[float, int]:
    """Count pooled non-negative motion values in half-decade bins.

    Values from all experiments are pooled. When any segment is given, only the period returned by
    ``extract_subject_period`` is used. Negative values are dropped, and decimation (if
    ``downsample_factor != 1``) happens before that.

    The result maps a bin's lower edge to the number of values in it. Key ``0`` counts exact zeros.
    The other keys are ``10**k`` and ``0.5 * 10**k``, from the largest decade (at least 1) down to
    the decade of the smallest positive value. Each bin spans from its edge up to the previous
    (larger) edge, and the first bin is unbounded above.
    """
    items = []
    for exp in experiments:
        data = exp.pos_data
        if hab_segment or trial_segment or post_segment:
            data = exp.extract_subject_period(hab_segment, trial_segment, post_segment)
        cm_per_px = exp.cm_per_px or 1
        if measure == "motion_index":
            motion = data.motion_index * cm_per_px
        else:
            _, motion, _, _ = data.calculate_speed()
        if downsample_factor != 1:
            motion = decimate(motion, downsample_factor)
        items.append(motion[motion >= 0])
    data = np.concatenate(items) if items else np.empty(0)
    n = len(data)

    zero_mask = data == 0
    counts = {0: np.sum(zero_mask).item()}
    data = data[~zero_mask]
    if not len(data):
        return counts

    min_order = int(math.floor(math.log10(np.min(data))))
    max_order = max(int(math.floor(math.log10(np.max(data)))), 0)
    for order in range(max_order, min_order - 1, -1):
        # Build the decade exactly. 10 ** (order + log10(0.5)) is only approximately 0.5 * 10**order,
        # which put a value of exactly 5 in the wrong bin.
        decade = float(10 ** order) if order >= 0 else 1 / 10 ** -order
        for edge in (decade, 0.5 * decade):
            mask = data >= edge
            counts[edge] = np.sum(mask).item()
            data = data[~mask]
    assert not len(data)
    assert sum(counts.values()) == n
    return counts
@classmethod
def plot_motion_index_range(
        cls, experiments: list["Experiment"], ax: plt.Axes = None,
        measure: Literal["speed", "motion_index"] = "motion_index",
) -> None:
    """Histogram the strictly positive motion values pooled over the experiments' full tracks.

    Bins are log-spaced (100 bins) and the x axis is logarithmic. When ``ax`` is None a new
    figure is created and shown; otherwise the histogram is drawn into ``ax`` and nothing is shown.
    """
    standalone = ax is None
    positive_values = []
    for exp in experiments:
        scale = exp.cm_per_px or 1
        if measure != "motion_index":
            _, motion, _, _ = exp.pos_data.calculate_speed()
        else:
            motion = exp.pos_data.motion_index * scale
        positive_values.append(motion[motion > 0])
    pooled = np.concatenate(positive_values)
    if standalone:
        _, ax = plt.subplots()
    ax.hist(pooled, bins=get_log_bins(pooled, 100), density=True)
    ax.set_xscale("log")
    ax.set_xlabel("Speed" if measure != "motion_index" else "Motion index")
    ax.set_ylabel("Density")
    if standalone:
        plt.show()
@classmethod
def plot_compare_motion_index_range(
        cls, data1: Path, data2: Path, title1: str, title2: str, frame_rate,
        measure: Literal["speed", "motion_index"] = "motion_index", cm_per_px: float | None = None,
) -> None:
    """Show the motion histograms of two tracking CSV files in two stacked panels with a shared x axis.

    Each file is parsed with ``SubjectPos.parse_csv_track`` and wrapped in a placeholder
    ``Experiment`` so that ``plot_motion_index_range`` can be reused.
    """
    fig, panels = plt.subplots(nrows=2, sharex=True)
    for ax, csv_path, panel_title in zip(panels, (data1, data2), (title1, title2)):
        track = SubjectPos.parse_csv_track(csv_path, frame_rate=frame_rate, cm_per_px=cm_per_px)
        placeholder = Experiment(0, 0, 0, 0, 0, 0, "", "")
        placeholder.pos_data = track
        Experiment.plot_motion_index_range([placeholder], ax=ax, measure=measure)
        ax.set_title(panel_title)
    plt.show()
| @classmethod | |
| def percentile_motion_range( | |
| cls, subjects: list["SubjectPos"], percentiles: float | Sequence[float], | |
| measure: Literal["speed", "motion_index"] = "motion_index", downsample_factor: int = 1, | |
| ) -> float | np.ndarray | np.floating: | |
| items = [] | |
| for subject in subjects: | |
| cm_per_px = subject.cm_per_px or 1 | |
| if measure == "motion_index": | |
| motion = subject.motion_index * cm_per_px | |
| else: | |
| _, motion, _, _ = subject.calculate_speed() | |
| if downsample_factor != 1: | |
| motion = decimate(motion, downsample_factor) | |
| items.append(motion[motion >= 0]) | |
| data = np.concatenate(items) | |
| return np.percentile(data, percentiles) | |
| def get_y_range_from_data( | |
| self, measure: Literal["motion_index", "speed"], ylog: bool, | |
| ylim: tuple[float, float] | None | |
| ) -> tuple[float, float]: | |
| if ylim: | |
| return ylim | |
| cm = self.cm_per_px or 1 | |
| low = float("inf") | |
| high = float("-inf") | |
| for data in (self.pre_pos_data, self.trial_pos_data, self.post_pos_data): | |
| index = data.motion_index * cm if measure == "motion_index" else data.calculate_speed(True)[1] | |
| if ylog: | |
| index = index[index > 0] | |
| else: | |
| index = index[index >= 0] | |
| low = min(np.min(index), low) | |
| high = max(np.max(index), high) | |
| if ylog: | |
| low = 10 ** (math.floor(2 * math.log10(low)) / 2) | |
| high = 10 ** (math.ceil(2 * math.log10(high)) / 2) | |
| else: | |
| low = round(low) | |
| high = round(high) | |
| return low, high | |
def generate_images_from_data(
        self, measure: Literal["motion_index", "speed"], ylog: bool,
        ylim: tuple[float, float] | None, frame_rate: float, time_window_padding: float,
        image_size: tuple[int, int], output_fig_h: int, root_output: Path,
        bouts_parameters: dict[str, dict[str, float]], thresholds: dict[str, float],
) -> str:
    """Render one PNG per frame showing the measure trace in a sliding window around that frame.

    Images go to ``root_output / measure`` as ``{title_root}_{frame}.png`` with a zero-padded
    frame number. Images that already exist are skipped.

    The plot marks samples at or below the freezing threshold in red and samples at or above the
    dart threshold in green. Detected freezing and dart bouts are drawn as horizontal segments at
    the bottom and top of the y range. The current frame is marked with a blue star.

    :param measure: ``"motion_index"`` (scaled by ``cm_per_px`` when set) or ``"speed"``.
    :param ylog: Use a logarithmic y axis (also affects the automatic y range).
    :param ylim: Explicit y limits, or None to derive them from the data.
    :param frame_rate: Frames per second, used to convert frame numbers to seconds.
    :param time_window_padding: Seconds shown on each side of the current frame.
    :param image_size: Only the width (index 0) is used, as the image width in pixels.
    :param output_fig_h: Image height in pixels.
    :param root_output: Directory under which a per-measure sub-directory is created.
    :param bouts_parameters: ``calculate_bouts`` keyword arguments, keyed ``"{measure}_freezing"`` / ``"{measure}_dart"``.
    :param thresholds: Marker thresholds keyed ``"{measure}_freezing"`` / ``"{measure}_dart"``.
    :return: An image-sequence pattern (``..._%0Nd.png``) matching the written file names.
    """
    root_output = root_output / measure
    root_output.mkdir(parents=True, exist_ok=True)
    cm = self.cm_per_px or 1
    low, high = self.get_y_range_from_data(measure, ylog, ylim)
    data = self.pos_data
    index = data.motion_index * cm if measure == "motion_index" else data.calculate_speed(True)[1]
    freezing_bouts = data.calculate_bouts(
        index, above_threshold=False, frame_rate=frame_rate, **bouts_parameters[f"{measure}_freezing"],
    )
    dart_bouts = data.calculate_bouts(
        index, above_threshold=True, frame_rate=frame_rate, **bouts_parameters[f"{measure}_dart"],
    )
    froze = index <= thresholds[f"{measure}_freezing"]
    darted = index >= thresholds[f"{measure}_dart"]
    n = len(index)
    pad_frames = int(time_window_padding * frame_rate)
    all_times = np.arange(n) / frame_rate
    # Digit count of the largest frame number (n - 1). This equals ceil(log10(n)) for n >= 2, so
    # existing image names stay valid, and it also works for n == 1 (gave 0) and n == 0 (raised).
    padding = len(str(max(n - 1, 0)))

    fig, ax = plt.subplots()
    try:
        dpi = 150
        fig.set_dpi(dpi)
        fig.set_size_inches(image_size[0] / dpi, output_fig_h / dpi)
        unit = "px" if not self.cm_per_px else "cm"
        ax.set_ylabel(measure.capitalize().replace("_", " ") + f" ({unit}/s)")
        ax.set_ylim(low, high)
        if ylog:
            ax.set_yscale("log")
        for side in ("top", "right", "bottom"):
            ax.spines[side].set_visible(False)
        ax.xaxis.set_visible(False)

        ax.plot(all_times, index, "k")
        ax.plot(all_times[froze], index[froze], ".r")
        ax.plot(all_times[darted], index[darted], ".g")
        for bouts, y, color in ((freezing_bouts, low, "r"), (dart_bouts, high, "g")):
            for start, duration in bouts:
                ax.plot([start, start + duration], [y, y], f"-{color}")
        text = ax.text(
            1, 1.15, "Y:",
            horizontalalignment='right',
            verticalalignment='top',
            transform=ax.transAxes
        )

        for i in tqdm.tqdm(range(n), total=n, unit="frames", desc=f"{self.title_root} - {measure}"):
            filename = root_output / f"{self.title_root}_{i:0{padding}}.png"
            if filename.exists():
                continue
            value = index[i].item()
            markers = ax.plot([all_times[i]], [value], "*b")
            text.set_text(f"Y: {value:0.5f}")
            ax.set_ylim(low, high)
            ax.set_xlim((i - pad_frames) / frame_rate, (i + pad_frames) / frame_rate)
            fig.tight_layout()
            fig.savefig(filename, dpi=dpi)
            # Drop the per-frame marker so the next frame starts from the static plot.
            for marker in markers:
                marker.remove()
    finally:
        # Release the figure; otherwise every call leaves an open pyplot figure (and its memory) behind.
        plt.close(fig)
    return str(root_output / f"{self.title_root}_%0{padding}d.png")
def merge_video_with_images(
        self, parameters: "ExperimentParameters", frame_rate: float, image_size: tuple[int, int],
        output_fig_h: int,
        time_window_padding: float, video_filename: str | Path, ffmpeg: str | Path, temp_images: bool,
        above_measure: Literal["motion_index", "speed"] | None,
        below_measure: Literal["motion_index", "speed"] | None,
        bouts_parameters: dict[str, dict[str, float]], thresholds: dict[str, float],
):
    """Stack per-frame measure plots above and/or below the video and encode the result with ffmpeg.

    The image sequences come from ``generate_images_from_data`` (log y axis, automatic range). They
    are written under a temporary directory when ``temp_images`` is set, otherwise under
    ``parameters.image_root``. The output goes to
    ``parameters.video_root / f"{title_root}_measures.mp4"`` (libx264, CRF 18). ffmpeg's stdout and
    stderr are printed after it exits.

    :param above_measure: Measure plotted above the video, or None for none.
    :param below_measure: Measure plotted below the video, or None for none.
    :raises subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    video_root = parameters.video_root
    video_root.mkdir(parents=True, exist_ok=True)
    tmp_dir = TemporaryDirectory() if temp_images else None
    try:
        image_root = Path(tmp_dir.name) if tmp_dir is not None else parameters.image_root
        image_root = image_root / self.title_root
        image_root.mkdir(parents=True, exist_ok=True)

        def image_input(measure):
            # ffmpeg input arguments for one rendered image sequence, or nothing if not requested.
            if measure is None:
                return ()
            img_pat = self.generate_images_from_data(
                measure, True, None, frame_rate, time_window_padding, image_size,
                output_fig_h, image_root, bouts_parameters, thresholds,
            )
            return "-r", str(frame_rate), "-i", img_pat

        above_arg = image_input(above_measure)
        below_arg = image_input(below_measure)
        # Count the inputs actually passed so the filter graph always matches them.
        n_inputs = 1 + bool(above_arg) + bool(below_arg)
        if n_inputs == 1:
            # vstack requires at least two inputs; pass the lone video stream through unchanged.
            filter_arg = "[0:v]null[v]"
        else:
            streams = "".join(f"[{k}:v]" for k in range(n_inputs))
            filter_arg = f"{streams}vstack=inputs={n_inputs}[v]"
        args = [
            str(ffmpeg),
            *above_arg,
            "-i",
            str(video_filename),
            *below_arg,
            "-filter_complex",
            filter_arg,
            "-map",
            "[v]",
            "-c:v",
            "libx264",
            "-crf",
            "18",
            str(video_root / f"{self.title_root}_measures.mp4"),
        ]
        print(f'Command: {" ".join(args)}')
        p = subprocess.run(args, capture_output=True)
    finally:
        if tmp_dir is not None:
            tmp_dir.cleanup()
    print("stdout:\n")
    print(p.stdout.decode())
    print("stderr:\n")
    print(p.stderr.decode())
    p.check_returncode()
@classmethod
def export_experiments_csv(
        cls, experiments: list["Experiment"], data_root: Path, occupancy_meas_args, cue_categoricals,
        freezing_index_meas_args, freezing_speed_meas_args, freeze_categoricals,
        cat_unit: Literal["times", "percent"], include_chopped: bool,
        frame_rate: float, bouts_parameters: dict[str, dict[str, Any]],
):
    """Write the standard set of CSV exports for ``experiments`` into ``data_root``.

    Files written: experiment metadata, box-side occupancy, motion-index and speed freezing
    (named after their ``"threshold"`` option), mean motion index and speed, and bouts for
    every measure in ``bouts_parameters``' four expected keys. With ``include_chopped`` the
    occupancy and mean-motion exports are repeated on the chopped data as ``*_pre.csv``.

    :param occupancy_meas_args: ``measure_options`` for the occupancy export.
    :param cue_categoricals: Categories for the occupancy exports.
    :param freezing_index_meas_args: ``measure_options`` for motion-index freezing; must contain
        ``"threshold"``.
    :param freezing_speed_meas_args: ``measure_options`` for speed freezing; must contain
        ``"threshold"``.
    :param freeze_categoricals: Categories for the freezing exports.
    :param bouts_parameters: Keyword arguments for ``export_bouts`` keyed by ``speed_freezing``,
        ``motion_index_freezing``, ``motion_index_dart`` and ``speed_dart``.
    """
    cls.export_multi_experiment_frames(experiments, data_root / "experiment_metadata.csv")
    cls.export_multi_experiment_categorical_percent(
        experiments, data_root / "box_sides.csv", measure="occupancy", measure_options=occupancy_meas_args,
        categoricals=cue_categoricals, unit=cat_unit,
    )
    for measure, options in (
            ("motion_index_freezing", freezing_index_meas_args),
            ("speed_freezing", freezing_speed_meas_args),
    ):
        cls.export_multi_experiment_categorical_percent(
            experiments, data_root / f"{measure}_threshold_{options['threshold']:0.5f}.csv",
            measure=measure, measure_options=options, categoricals=freeze_categoricals, unit=cat_unit,
        )
    for measure in ("motion_index", "speed"):
        cls.export_multi_experiment_motion(experiments, data_root / f"{measure}_mean.csv", measure=measure)
    for measure in ("speed_freezing", "motion_index_freezing", "motion_index_dart", "speed_dart"):
        cls.export_bouts(
            experiments, data_root / f"bouts_{measure}.csv", measure, unit=cat_unit,
            frame_rate=frame_rate, **bouts_parameters[measure],
        )

    if include_chopped:
        # Same categories and explicit measures as the full-data exports above, so the
        # *_pre.csv files are directly comparable with their counterparts.
        cls.export_multi_experiment_categorical_percent(
            experiments, data_root / "box_sides_pre.csv", use_chopped_data=True,
            measure="occupancy", measure_options=occupancy_meas_args, categoricals=cue_categoricals,
            unit=cat_unit,
        )
        for measure in ("motion_index", "speed"):
            cls.export_multi_experiment_motion(
                experiments, data_root / f"{measure}_mean_pre.csv", use_chopped_data=True, measure=measure,
            )
@classmethod
def export_single_subject_figures(
        cls, experiments: list["Experiment"], figure_root: Path, occupancy_meas_args,
        cat_unit: Literal["times", "percent"], prefix_pat: str, frame_rate: float,
        bouts_parameters: dict[str, dict[str, Any]],
):
    """Save the standard per-subject figure set for every experiment.

    Figures for one experiment go under ``figure_root / "subject" / label``, where ``label`` is
    ``prefix_pat`` formatted with that experiment's ``metadata``. Bout figures go into a
    ``bouts/<plot_type>`` sub-directory.

    :param occupancy_meas_args: ``measure_options`` for the side-of-box plot.
    :param cat_unit: Unit for the categorical and bout plots.
    :param bouts_parameters: Keyword arguments for ``plot_bouts`` keyed by ``speed_freezing``,
        ``motion_index_freezing``, ``motion_index_dart`` and ``speed_dart``.
    """
    bout_measures = ("speed_freezing", "motion_index_freezing", "motion_index_dart", "speed_dart")
    for experiment in tqdm.tqdm(experiments, desc="subject"):
        label = prefix_pat.format(**experiment.metadata)
        subject_root = figure_root / "subject" / label

        experiment.plot_occupancy(
            scale_to_one=False, intensity_limit=1e-5,
            save_fig_root=subject_root, save_fig_prefix=f"intensity_limit100K_{label}",
        )
        # NOTE(review): unlike the 100K variant this does not pass scale_to_one=False; confirm
        # that is intended together with a 1e-6 intensity limit.
        experiment.plot_occupancy(
            intensity_limit=1e-6,
            save_fig_root=subject_root, save_fig_prefix=f"intensity_limit1M_{label}",
        )
        for sigma in (5, 1):
            experiment.plot_occupancy(
                gaussian_sigma=sigma,
                save_fig_root=subject_root, save_fig_prefix=f"occupancy_sigma{sigma}_{label}",
            )

        experiment.plot_motion_index_histogram(
            n_bins=50, hist_range=(0, .3),
            save_fig_root=subject_root, save_fig_prefix=f"motion_index_histogram_{label}",
        )
        experiment.plot_motion_index_histogram(
            n_bins=25, hist_range=(1e-5, .3),
            save_fig_root=subject_root, save_fig_prefix=f"motion_index_histogram_log_{label}", log_x=True,
        )
        experiment.plot_motion_index(
            y_limit=.2, save_fig_root=subject_root, save_fig_prefix=f"motion_index_{label}",
        )
        experiment.plot_speed_histogram(
            n_bins=50, hist_range=(0, 30),
            save_fig_root=subject_root, save_fig_prefix=f"speed_histogram_{label}",
        )
        experiment.plot_speed_histogram(
            n_bins=25, hist_range=(1e-3, 30),
            save_fig_root=subject_root, save_fig_prefix=f"speed_histogram_log_{label}", log_x=True,
        )
        experiment.plot_speed(
            y_limit=20, save_fig_root=subject_root, save_fig_prefix=f"speed_{label}",
        )
        experiment.plot_distance_from_point(
            save_fig_root=subject_root, save_fig_prefix=f"distance_{label}",
        )
        experiment.plot_categorical_percent(
            measure="occupancy", measure_options=occupancy_meas_args, unit=cat_unit,
            save_fig_root=subject_root, save_fig_prefix=f"side_{label}", x_label="",
        )

        for plot_type in ("scatter", "count", "mean", "duration"):
            for measure in bout_measures:
                experiment.plot_bouts(
                    measure, frame_rate=frame_rate, plot_type=plot_type, unit=cat_unit,
                    save_fig_root=subject_root / "bouts" / plot_type,
                    save_fig_prefix=f"{measure}_{label}", **bouts_parameters[measure],
                )
| @classmethod | |
| def export_multi_experiment_figures( | |
| cls, experiments: list["Experiment"], figure_root: Path, occupancy_meas_args, | |
| cat_unit: Literal["times", "percent"], label: str, cue_categoricals, freezing_index_meas_args, | |
| freezing_speed_meas_args, freeze_categoricals, frame_rate: float, | |
| bouts_parameters: dict[str, dict[str, Any]], | |
| filter_args: list[dict] | None = None, group_label: str = "", sub_dir: str = "grouped", | |
| ): | |
| cls.plot_multi_experiment_occupancy( | |
| experiments, | |
| title=f"{label} occupancy density", scale_to_one=False, | |
| intensity_limit=1e-5, filter_args=filter_args, group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "occupancy" / "100K", | |
| save_fig_prefix=f"occupancy_{label}_intensity_limit100K", | |
| ) | |
| cls.plot_multi_experiment_occupancy( | |
| experiments, filter_args=filter_args, group_label=group_label, | |
| title=f"{label} occupancy density", intensity_limit=1e-6, | |
| save_fig_root=figure_root / sub_dir / "occupancy" / "1M", | |
| save_fig_prefix=f"occupancy_{label}_intensity_limit1M", | |
| ) | |
| cls.plot_multi_experiment_occupancy( | |
| experiments, title=f"{label} occupancy", gaussian_sigma=5, | |
| save_fig_root=figure_root / sub_dir / "occupancy" / "sigma5", | |
| save_fig_prefix=f"occupancy_{label}_sigma5", | |
| filter_args=filter_args, group_label=group_label, | |
| ) | |
| cls.plot_multi_experiment_occupancy( | |
| experiments, title=f"{label} occupancy", gaussian_sigma=1, | |
| save_fig_root=figure_root / sub_dir / "occupancy" / "sigma1", | |
| save_fig_prefix=f"occupancy_{label}_sigma1", | |
| filter_args=filter_args, group_label=group_label, | |
| ) | |
| cls.plot_multi_experiment_motion_index_histogram( | |
| experiments, title=f"{label} motion index density", n_bins=50, filter_args=filter_args, group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "motion_index_histogram", hist_range=(0, .3), | |
| save_fig_prefix=f"motion_index_histogram_{label}", | |
| ) | |
| cls.plot_multi_experiment_motion_index_histogram( | |
| experiments, title=f"{label} motion index density", n_bins=25, filter_args=filter_args, | |
| group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "motion_index_histogram", hist_range=(1e-5, .3), | |
| save_fig_prefix=f"motion_index_histogram_log_{label}", log_x=True, | |
| ) | |
| cls.plot_multi_experiment_motion( | |
| experiments, y_limit=.2, title=f"{label} motion index", measure="motion_index", | |
| save_fig_root=figure_root / sub_dir / "motion_index", | |
| save_fig_prefix=f"motion_index_{label}", | |
| filter_args=filter_args, group_label=group_label, | |
| ) | |
| cls.plot_multi_experiment_speed_histogram( | |
| experiments, title=f"{label} speed density", n_bins=50, filter_args=filter_args, group_label=group_label, | |
| hist_range=(0, 30), | |
| save_fig_root=figure_root / sub_dir / "speed_histogram", | |
| save_fig_prefix=f"speed_histogram_{label}", | |
| ) | |
| cls.plot_multi_experiment_speed_histogram( | |
| experiments, title=f"{label} speed density", n_bins=25, filter_args=filter_args, group_label=group_label, | |
| hist_range=(1e-3, 30), | |
| save_fig_root=figure_root / sub_dir / "speed_histogram", | |
| save_fig_prefix=f"speed_histogram_log_{label}", log_x=True, | |
| ) | |
| cls.plot_multi_experiment_motion( | |
| experiments, y_limit=20, title=f"{label} speed", measure="speed", | |
| filter_args=filter_args, group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "speed", | |
| save_fig_prefix=f"speed_{label}", | |
| ) | |
| cls.plot_multi_experiment_distance_from_point( | |
| experiments, title=f"{label} distance from cue", | |
| filter_args=filter_args, group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "distance", save_fig_prefix=f"distance_{label}", | |
| ) | |
| cls.plot_multi_experiment_categorical_percent( | |
| experiments, title=f"Side of box duration", measure="occupancy", measure_options=occupancy_meas_args, | |
| categoricals=cue_categoricals, unit=cat_unit, | |
| filter_args=filter_args, group_label=group_label, x_label="", | |
| save_fig_root=figure_root / sub_dir / "side" / "side", | |
| save_fig_prefix=f"side_{label}", | |
| ) | |
| cls.plot_multi_experiment_categorical_percent( | |
| experiments, title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}", | |
| measure="motion_index_freezing", | |
| measure_options=freezing_index_meas_args, | |
| categoricals=freeze_categoricals, unit=cat_unit, | |
| filter_args=filter_args, group_label=group_label, x_label="", | |
| save_fig_root=figure_root / sub_dir / "motion_index_freezing" / "freezing", | |
| save_fig_prefix=f"freezing_{label}", | |
| ) | |
| cls.plot_multi_experiment_categorical_percent( | |
| experiments, title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}", | |
| measure="speed_freezing", | |
| measure_options=freezing_speed_meas_args, | |
| categoricals=freeze_categoricals, unit=cat_unit, | |
| filter_args=filter_args, group_label=group_label, x_label="", | |
| save_fig_root=figure_root / sub_dir / "speed_freezing" / "freezing", | |
| save_fig_prefix=f"freezing_{label}", | |
| ) | |
| for plot_type in ["scatter", "count", "mean", "duration"]: | |
| cls.plot_multi_experiment_bouts( | |
| experiments, "speed_freezing", frame_rate=frame_rate, plot_type=plot_type, | |
| filter_args=filter_args, group_label=group_label, unit=cat_unit, | |
| save_fig_root=figure_root / sub_dir / "bouts" / plot_type, save_fig_prefix=f"speed_freezing_{label}", | |
| **bouts_parameters["speed_freezing"], | |
| ) | |
| cls.plot_multi_experiment_bouts( | |
| experiments, "motion_index_freezing", frame_rate=frame_rate, plot_type=plot_type, | |
| filter_args=filter_args, group_label=group_label, unit=cat_unit, | |
| save_fig_root=figure_root / sub_dir / "bouts" / plot_type, | |
| save_fig_prefix=f"motion_index_freezing_{label}", | |
| **bouts_parameters["motion_index_freezing"], | |
| ) | |
| cls.plot_multi_experiment_bouts( | |
| experiments, "motion_index_dart", frame_rate=frame_rate, plot_type=plot_type, | |
| filter_args=filter_args, group_label=group_label, unit=cat_unit, | |
| save_fig_root=figure_root / sub_dir / "bouts" / plot_type, save_fig_prefix=f"motion_index_dart_{label}", | |
| **bouts_parameters["motion_index_dart"], | |
| ) | |
| cls.plot_multi_experiment_bouts( | |
| experiments, "speed_dart", frame_rate=frame_rate, plot_type=plot_type, | |
| filter_args=filter_args, group_label=group_label, unit=cat_unit, | |
| save_fig_root=figure_root / sub_dir / "bouts" / plot_type, save_fig_prefix=f"speed_dart_{label}", | |
| **bouts_parameters["speed_dart"], | |
| ) | |
| if not filter_args: | |
| return | |
| cls.plot_multi_experiment_merged_by_period_categorical_percent( | |
| experiments, title=f"Side of box duration", measure="occupancy", measure_options=occupancy_meas_args, | |
| categoricals=cue_categoricals, x_label="", unit=cat_unit, | |
| filter_args=filter_args, group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "side" / "side_merged_by_period", | |
| save_fig_prefix=f"side_merged_by_period_{label}", | |
| ) | |
| cls.plot_multi_experiment_merged_by_group_categorical_percent( | |
| experiments, title=f"Total time spent in cue side", show_titles=False, show_legend=True, | |
| measure="occupancy", measure_options=occupancy_meas_args, unit=cat_unit, | |
| categoricals=cue_categoricals, x_label="Cue side", | |
| filter_args=filter_args, only_categorical="Cue-side", show_xlabel=False, | |
| save_fig_root=figure_root / sub_dir / "side" / "side_merged_by_group", | |
| save_fig_prefix=f"side_merged_by_group_{label}", group_label=group_label, | |
| ) | |
| cls.plot_multi_experiment_merged_by_period_categorical_percent( | |
| experiments, title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}", | |
| measure="motion_index_freezing", | |
| measure_options=freezing_index_meas_args, | |
| categoricals=freeze_categoricals, x_label="", unit=cat_unit, | |
| filter_args=filter_args, group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "motion_index_freezing" / | |
| "freezing_merged_by_period", | |
| save_fig_prefix=f"freezing_merged_by_period_{label}", | |
| ) | |
| cls.plot_multi_experiment_merged_by_group_categorical_percent( | |
| experiments, | |
| title=f"Total time spent freezing (motion index). Threshold={freezing_index_meas_args['threshold']}", | |
| show_titles=False, show_legend=True, | |
| measure="motion_index_freezing", measure_options=freezing_index_meas_args, unit=cat_unit, | |
| categoricals=freeze_categoricals, x_label="Freezing", | |
| filter_args=filter_args, only_categorical="Freezing", show_xlabel=False, | |
| save_fig_root=figure_root / sub_dir / "motion_index_freezing" / | |
| "freezing_merged_by_group", | |
| save_fig_prefix=f"freezing_merged_by_group_{label}", group_label=group_label, | |
| ) | |
| cls.plot_multi_experiment_merged_by_period_categorical_percent( | |
| experiments, title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}", | |
| measure="speed_freezing", | |
| measure_options=freezing_speed_meas_args, | |
| categoricals=freeze_categoricals, x_label="", unit=cat_unit, | |
| filter_args=filter_args, group_label=group_label, | |
| save_fig_root=figure_root / sub_dir / "speed_freezing" / | |
| "freezing_merged_by_period", | |
| save_fig_prefix=f"freezing_merged_by_period_{label}", | |
| ) | |
| cls.plot_multi_experiment_merged_by_group_categorical_percent( | |
| experiments, | |
| title=f"Total time spent freezing (speed). Threshold={freezing_speed_meas_args['threshold']}", | |
| show_titles=False, show_legend=True, | |
| measure="speed_freezing", measure_options=freezing_speed_meas_args, unit=cat_unit, | |
| categoricals=freeze_categoricals, x_label="Freezing", | |
| filter_args=filter_args, only_categorical="Freezing", show_xlabel=False, | |
| save_fig_root=figure_root / sub_dir / "speed_freezing" / | |
| "freezing_merged_by_group", | |
| save_fig_prefix=f"freezing_merged_by_group_{label}", group_label=group_label, | |
| ) | |
@classmethod
def print_measure_hist(
        cls, experiments: list["Experiment"], percentile: float, downsample_factor: int = 1,
        hab_segment: tuple[float, float] = (), trial_segment: tuple[float, float] = (),
        post_segment: tuple[float, float] = (),
) -> None:
    """Print motion-index and speed range counts and percentiles for ``experiments``.

    At most one of ``hab_segment``, ``trial_segment`` and ``post_segment`` may be non-empty.
    The range counts always receive all three segment arguments. The percentiles are computed
    on each experiment's ``extract_subject_period(...)`` when a segment is given, otherwise on
    its ``pos_data``; the printed message names the segment actually used.

    :raises ValueError: If more than one segment is specified.
    """
    segments = {"hab_segment": hab_segment, "trial_segment": trial_segment, "post_segment": post_segment}
    chosen = [(name, seg) for name, seg in segments.items() if seg]
    if len(chosen) > 1:
        raise ValueError("Only one of hab_segment, trial_segment, post_segment may be specified")

    print(f"Downsampling measures at {downsample_factor} from original frame rate")
    print("Motion index counts between number to next larger number:")
    pprint.pprint(cls.count_motion_range(
        experiments, measure="motion_index", downsample_factor=downsample_factor, **segments,
    ))
    print("Speed counts between number to next larger number:")
    pprint.pprint(cls.count_motion_range(
        experiments, measure="speed", downsample_factor=downsample_factor, **segments,
    ))

    if chosen:
        data = [exp.extract_subject_period(**segments) for exp in experiments]
        name, seg = chosen[0]
        # Report the segment that was really used rather than a fixed, hard-coded range.
        span = f"{name}={seg}"
    else:
        data = [exp.pos_data for exp in experiments]
        span = "unsegmented data"

    motion_index_percentile = cls.percentile_motion_range(
        data, percentile, measure="motion_index", downsample_factor=downsample_factor,
    )
    speed_percentile = cls.percentile_motion_range(
        data, percentile, measure="speed", downsample_factor=downsample_factor,
    )
    print(f"Motion index at {percentile} percentile over {span} is {motion_index_percentile}")
    print(f"Speed at {percentile} percentile over {span} is {speed_percentile}")
| @dataclass | |
| class SubjectPos: | |
| filename: str | Path | |
| times: np.ndarray | |
| track: np.ndarray = field(repr=False) | |
| motion_index: np.ndarray = field(repr=False) | |
| cm_per_px: float | None = None | |
| _point_marker: dict = field(default_factory=dict, init=False, repr=False) | |
@property
def min_x(self):
    """Smallest x coordinate (column 0 of ``track``); negative entries are not filtered out."""
    x_column = self.track[:, 0]
    return x_column.min()
@property
def min_y(self):
    """Smallest y coordinate (column 1 of ``track``); negative entries are not filtered out."""
    y_column = self.track[:, 1]
    return y_column.min()
@property
def max_x(self):
    """Largest x coordinate (column 0 of ``track``)."""
    x_column = self.track[:, 0]
    return x_column.max()
@property
def max_y(self):
    """Largest y coordinate (column 1 of ``track``)."""
    y_column = self.track[:, 1]
    return y_column.max()
@property
def duration(self) -> float:
    """Time between the first and last sample in ``times`` (0 when there are no samples)."""
    stamps = self.times
    if len(stamps) > 0:
        return (stamps[-1] - stamps[0]).item()
    return 0
| @classmethod | |
| def parse_csv_track( | |
| cls, filename: Path, subject_name: str = "mouse", frame_rate: float = 30., cm_per_px: float | None = None, | |
| ) -> "SubjectPos": | |
| track = [] | |
| times = [] | |
| index = [] | |
| print(f"Reading {filename}") | |
| with open(filename, "r") as fh: | |
| reader = csv.reader(fh) | |
| header = list(next(reader)) | |
| last_t = None | |
| last_frame_t = None | |
| for i, row in enumerate(reader): | |
| frame, instance, cx, cy, index_val, t1, t2 = row | |
| if header[-1] == "real_timestamp_sec": | |
| t = t2 | |
| elif header[-2] == "real_timestamp_sec": | |
| t = t1 | |
| else: | |
| raise ValueError("Cannot find the real_timestamp_sec column") | |
| if instance != subject_name: | |
| continue | |
| track.append((float(cx), float(cy))) | |
| index.append(float(index_val)) | |
| t = float(t) | |
| frame_t = float(frame) / frame_rate | |
| if not (frame_t * 0.95 <= t <= frame_t * 1.05): | |
| raise ValueError(f"{filename} frame rate of {frame_rate} does not match time stamp at \"{row}\"") | |
| if last_t is not None and last_t >= t or last_frame_t is not None and last_frame_t >= frame_t: | |
| raise ValueError(f"{filename} sequential times are not increasing \"{row}\"") | |
| last_t = t | |
| last_frame_t = frame_t | |
| times.append(frame_t) | |
| times = np.array(times) | |
| return SubjectPos( | |
| filename=filename, times=np.array(times), track=np.array(track), motion_index=np.array(index), | |
| cm_per_px=cm_per_px, | |
| ) | |
| def extract_range(self, t_start: float | None = None, t_end: float | None = None) -> "SubjectPos": | |
| if t_start is None: | |
| t_start = self.times[0] | |
| if t_end is None: | |
| t_end = self.times[-1] + 1 | |
| i_s = np.sum(self.times < t_start) | |
| i_e = np.sum(self.times <= t_end) | |
| return SubjectPos( | |
| filename=self.filename, times=self.times[i_s:i_e], track=self.track[i_s:i_e], | |
| motion_index=self.motion_index[i_s:i_e], cm_per_px=self.cm_per_px, | |
| ) | |
@classmethod
def calculate_bouts(
        cls, data: np.ndarray, threshold: float, above_threshold: bool, frame_rate: float,
        refractory_period: float = 0, downsample_factor: int = 1, min_bout_frames: int = 0,
) -> list[tuple[float, float]]:
    """Find bouts during which ``data`` meets ``threshold``.

    A sample qualifies when ``data >= threshold`` (``above_threshold=True``) or
    ``data <= threshold`` otherwise, and runs of qualifying samples form bouts. With a non-zero
    ``refractory_period`` (rounded to at least one frame), a bout only ends once that many frames
    have passed since its last qualifying sample. Shorter dips are bridged, and the trailing
    non-qualifying frames are not counted in its duration.

    :param frame_rate: Sampling rate of ``data`` before downsampling.
    :param downsample_factor: If not 1, ``data`` is decimated first; all frame counts then refer
        to the decimated rate.
    :param min_bout_frames: Bouts shorter than this many (decimated) frames are dropped.
    :return: ``(start, duration)`` pairs in seconds, measured from the first sample of ``data``;
        always empty when ``data`` has 50 or fewer samples.
    """
    rate = frame_rate / downsample_factor
    gap_frames = int(round(refractory_period * rate))
    if refractory_period:
        gap_frames = max(gap_frames, 1)

    # NOTE(review): short inputs are skipped entirely, presumably so decimate() has enough samples.
    if len(data) <= 50:
        return []
    if downsample_factor != 1:
        data = decimate(data, downsample_factor)

    hits = (data >= threshold) if above_threshold else (data <= threshold)

    segments: list[tuple[int, int]] = []  # (start frame, length in frames)
    open_at = None  # start frame of the bout in progress, if any
    last_hit = None  # most recent frame that met the threshold
    for frame, hit in enumerate(hits):
        if hit:
            last_hit = frame
            if open_at is None:
                open_at = frame
            continue
        if open_at is None:
            continue
        if not gap_frames:
            segments.append((open_at, frame - open_at))
            open_at = None
        elif frame - last_hit >= gap_frames:
            segments.append((open_at, last_hit - open_at + 1))
            open_at = None
    if open_at is not None:
        segments.append((open_at, last_hit - open_at + 1))

    if min_bout_frames:
        segments = [seg for seg in segments if seg[1] >= min_bout_frames]
    return [(start / rate, length / rate) for start, length in segments]
| def calculate_occupancy( | |
| self, occupancy: np.ndarray, pos_to_index: Callable | None = None, frame_normalize: bool = True, | |
| ) -> None: | |
| n = self.track.shape[0] | |
| if not n: | |
| return | |
| grid_width, grid_height = occupancy.shape | |
| frame_proportion = 1 | |
| if frame_normalize: | |
| frame_proportion = 1 / n | |
| for i in range(n): | |
| if self.track[i, 0] < 0 or self.track[i, 1] < 0: | |
| continue | |
| if pos_to_index is None: | |
| x, y = self.track[i, :] | |
| else: | |
| x, y = pos_to_index(*self.track[i, :], grid_width, grid_height) | |
| x = int(min(x, grid_width)) | |
| y = int(min(y, grid_height)) | |
| occupancy[x, y] += frame_proportion | |
def calculate_speed(self, pad_to_input: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute speed between consecutive valid samples.

    Samples with a negative x or y are dropped first, so each step spans two consecutive valid
    samples. Distances are scaled by ``cm_per_px`` (1 if unset).

    :param pad_to_input: If False, return ``(times, speed, dist, elapsed)`` with one entry per
        step, ``times`` being the time at the end of each step. If True, every array has one entry
        per valid sample. ``times`` are then the valid sample times, and the first step's speed,
        distance and elapsed time are repeated for the first sample.
    :raises ValueError: If two consecutive valid samples do not have strictly increasing times.
    """
    valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0)
    track = self.track[valid, :]
    times = self.times[valid]
    elapsed = times[1:] - times[:-1]
    if np.any(elapsed <= 0):
        raise ValueError(f"Found sequential time stamps whose difference is not positive {self}")
    cm_per_px = self.cm_per_px or 1
    step = track[1:, :] - track[:-1, :]
    dist = np.sqrt(np.sum(np.square(step), axis=1, keepdims=False)) * cm_per_px
    speed = dist / elapsed
    if pad_to_input:
        # Align with the valid input samples: use their own times, rather than edge-padding
        # times[1:], which would repeat the second time stamp and drop the first.
        # NOTE(review): np.pad's "edge" mode cannot pad an empty array, so this presumably
        # raises when there are fewer than two valid samples.
        return (
            times,
            np.pad(speed, (1, 0), "edge"),
            np.pad(dist, (1, 0), "edge"),
            np.pad(elapsed, (1, 0), "edge"),
        )
    return times[1:], speed, dist, elapsed
@classmethod
def plot_occupancy_data(
        cls, occupancy: np.ndarray, fig: plt.Figure, ax: plt.Axes,
        gaussian_sigma: float = 0, intensity_limit: float = 0, scale_to_one: bool = True,
        x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
        color_bar: bool = True, cm_per_px: float = 1,
):
    """Draw an occupancy grid as an image on ``ax``.

    ``occupancy`` is indexed ``[x, y]``. It is transposed for display so x runs horizontally and
    y downward (``origin="upper"``), and the axes extent is scaled by ``cm_per_px``. The caller's
    array is never modified.

    :param gaussian_sigma: If non-zero, smooth the grid with ``gaussian_filter`` first.
    :param intensity_limit: If non-zero, used as the colour-scale maximum (``vmax``).
    :param scale_to_one: Divide by the grid maximum (an all-zero grid stays all zero).
    :param color_bar: Attach a colour bar to the right of ``ax``.
    """
    if gaussian_sigma:
        occupancy = gaussian_filter(occupancy, gaussian_sigma)
    if scale_to_one:
        max_val = occupancy.max()
        # Out of place: dividing in place would overwrite the caller's grid when no smoothing
        # was applied, and fails outright for integer grids.
        occupancy = occupancy / max_val if max_val else np.zeros(occupancy.shape)
    n_rows, n_cols = occupancy.T.shape
    extent = (-0.5 * cm_per_px, (n_cols - 0.5) * cm_per_px, (n_rows - 0.5) * cm_per_px, -0.5 * cm_per_px)
    im = ax.imshow(
        occupancy.T, aspect="auto", origin="upper", cmap="viridis", interpolation="sinc",
        interpolation_stage="data", vmax=intensity_limit or None, extent=extent,
    )
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    if color_bar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
def plot_occupancy(
        self, grid_width: int, grid_height: int,
        pos_to_index: Callable | None = None, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
        gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
        scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
        x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
        color_bar: bool = True,
):
    """Plot this subject's occupancy as a ``grid_width`` x ``grid_height`` heat map.

    Positions are binned with ``calculate_occupancy`` and drawn with ``plot_occupancy_data``.
    ``fig`` and ``ax`` must be given together or not at all. When neither is given, a new figure
    is created and then saved or shown via ``save_or_show``.
    """
    grid = np.zeros((grid_width, grid_height))
    standalone = ax is None
    if standalone:
        assert fig is None
        fig, ax = plt.subplots()
    else:
        assert fig is not None

    self.calculate_occupancy(grid, pos_to_index, frame_normalize)
    self.plot_occupancy_data(
        grid, fig, ax, gaussian_sigma, intensity_limit, scale_to_one, x_label, y_label, title,
        color_bar, cm_per_px=self.cm_per_px or 1,
    )

    if standalone:
        save_or_show(save_fig_root, save_fig_prefix)
def plot_motion_index(
        self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        x_label: str = "Time (min)", y_label: str = "Motion index",
):
    """Plot the motion index (scaled by ``cm_per_px``, 1 if unset) against time in minutes.

    Samples whose scaled value is negative are left out. Time is measured from the first entry of
    ``times``. ``fig``/``ax`` go together; when omitted a new figure is created and then saved or
    shown via ``save_or_show``.
    """
    standalone = ax is None
    if standalone:
        assert fig is None
        fig, ax = plt.subplots()
    else:
        assert fig is not None

    scaled = self.motion_index * (self.cm_per_px or 1)
    keep = scaled >= 0
    kept_times = self.times[keep]
    if len(kept_times):
        # NOTE(review): _point_marker is a dataclass field with default_factory, which dataclasses
        # does not keep as a class attribute, so it presumably has to be assigned on the class
        # elsewhere for this lookup to work.
        ax.plot(kept_times / 60 - self.times[0] / 60, scaled[keep], **SubjectPos._point_marker)
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)

    if standalone:
        save_or_show(save_fig_root, save_fig_prefix)
def plot_speed(
        self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        x_label: str = "Time (min)", y_label: str = "Speed",
):
    """Plot speed from ``calculate_speed`` against minutes since the first sample.

    ``fig``/``ax`` go together; when omitted a new figure is created and then saved or shown via
    ``save_or_show``.
    """
    standalone = ax is None
    if standalone:
        assert fig is None
        fig, ax = plt.subplots()
    else:
        assert fig is not None

    step_times, speed = self.calculate_speed()[:2]
    if len(step_times):
        # NOTE(review): relies on SubjectPos._point_marker being assigned on the class elsewhere.
        ax.plot(step_times / 60 - self.times[0] / 60, speed, **SubjectPos._point_marker)
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)

    if standalone:
        save_or_show(save_fig_root, save_fig_prefix)
def transform_to_categorical_pos(
        self, position_to_side: Callable, categoricals: tuple[str, ...],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Map each valid sample to a category index, paired with the time until the next sample.

    The last sample is dropped because it has no following sample. Samples with a negative x or y
    are skipped.

    :param position_to_side: ``f(x, y) -> name``; every returned name must be in ``categoricals``.
    :param categoricals: Category names; a sample's index is its name's position in this tuple.
    :return: ``(times, time_diff, categoricals_index)`` for the kept samples; the index array is
        always integer, even when empty, so it can be used for indexing.
    :raises KeyError: If ``position_to_side`` returns a name not in ``categoricals``.
    """
    track = self.track[:-1, :]
    valid = np.logical_and(track[:, 0] >= 0, track[:, 1] >= 0)
    name_to_index = {name: i for i, name in enumerate(categoricals)}
    categoricals_index = np.array(
        [name_to_index[position_to_side(*p)] for p in track[valid, :]], dtype=int,
    )
    time_diff = self.times[1:] - self.times[:-1]
    assert np.all(time_diff >= 0)
    return self.times[:-1][valid], time_diff[valid], categoricals_index
def plot_categorical_values(
        self, times, categoricals_index, categoricals: tuple[str, ...],
        fig: plt.Figure | None = None, ax: plt.Axes | None = None,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Scatter category indices against minutes since this subject's first sample.

    :param times: Sample times, in the same units as ``self.times``.
    :param categoricals_index: Category index per entry of ``times``.
    :param categoricals: Names used as y tick labels, one per index.
    ``fig``/``ax`` go together; when omitted a new figure is created and then saved or shown
    via ``save_or_show``.
    """
    standalone = ax is None
    if standalone:
        assert fig is None
        fig, ax = plt.subplots()
    else:
        assert fig is not None

    minutes = times / 60 - self.times[0] / 60
    ax.plot(minutes, categoricals_index, "*", alpha=.3, ms=5)
    tick_positions = np.arange(len(categoricals))
    ax.set_yticks(tick_positions, categoricals)
    ax.set_xlabel("Time (min)")
    ax.set_ylabel("Side of box relative to teaball")

    if standalone:
        save_or_show(save_fig_root, save_fig_prefix)
def plot_distance_from_point(
        self, point_xy: tuple[int, int], fig: plt.Figure | None = None, ax: plt.Axes | None = None,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        x_label: str = "Time (min)", y_label: str = "Distance",
):
    """Plot the distance from ``point_xy`` against minutes since the first sample.

    ``point_xy`` is in the same units as ``track``. Distances are scaled by ``cm_per_px`` (1 if
    unset), and samples with a negative x or y are left out. ``fig``/``ax`` go together; when
    omitted a new figure is created and then saved or shown via ``save_or_show``.
    """
    standalone = ax is None
    if standalone:
        assert fig is None
        fig, ax = plt.subplots()
    else:
        assert fig is not None

    in_frame = (self.track[:, 0] >= 0) & (self.track[:, 1] >= 0)
    offsets = self.track[in_frame, :] - np.array(point_xy)[None, :]
    distance = np.sqrt(np.sum(np.square(offsets), axis=1)) * (self.cm_per_px or 1)
    kept_times = self.times[in_frame]
    if len(kept_times):
        # NOTE(review): relies on SubjectPos._point_marker being assigned on the class elsewhere.
        ax.plot(kept_times / 60 - self.times[0] / 60, distance, **SubjectPos._point_marker)
    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)

    if standalone:
        save_or_show(save_fig_root, save_fig_prefix)
| def calculate_measure_bout( | |
| self, measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float, | |
| refractory_period: float = 0, downsample_factor: int = 1, min_bout_frames: int = 0, | |
| ) -> list[tuple[float, float]]: | |
| match measure: | |
| case "motion_index_freezing": | |
| cm_per_px = self.cm_per_px or 1 | |
| index = self.motion_index * cm_per_px | |
| above = False | |
| case "motion_index_dart": | |
| cm_per_px = self.cm_per_px or 1 | |
| index = self.motion_index * cm_per_px | |
| above = True | |
| case "speed_freezing": | |
| times, index, _, time_diff = self.calculate_speed() | |
| above = False | |
| case "speed_dart": | |
| times, index, _, time_diff = self.calculate_speed() | |
| above = True | |
| case _: | |
| assert False | |
| bouts = self.calculate_bouts( | |
| index, threshold, above_threshold=above, frame_rate=frame_rate, | |
| refractory_period=refractory_period, | |
| downsample_factor=downsample_factor, min_bout_frames=min_bout_frames, | |
| ) | |
| return bouts | |
| def plot_bout( | |
| self, measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float, refractory_period: float = 0, | |
| downsample_factor: int = 1, min_bout_frames: int = 0, | |
| plot_type: Literal["scatter", "count", "mean", "duration"] = "scatter", | |
| fig: plt.Figure | None = None, ax: plt.Axes | None = None, | |
| save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
| x_label: str = "Time (min)", y_label: str = "Duration (s)", norm_time: bool = False, | |
| ) -> list[tuple[float, float]]: | |
| show_plot = ax is None | |
| if ax is None: | |
| assert fig is None | |
| fig, ax = plt.subplots() | |
| else: | |
| assert fig is not None | |
| bouts = self.calculate_measure_bout( | |
| measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames | |
| ) | |
| match plot_type: | |
| case "scatter": | |
| if len(bouts): | |
| ax.plot([b[0] / 60 for b in bouts], [b[1] for b in bouts], **SubjectPos._point_marker) | |
| if len(self.times): | |
| ax.set_xlim(0, (self.times[-1].item() - self.times[0].item()) / 60) | |
| case "count": | |
| ax.bar([""], [len(bouts) / (self.duration if norm_time else 1)]) | |
| case "mean": | |
| val = sum(b[1] for b in bouts) / len(bouts) if bouts else 0 | |
| ax.bar([0], [val]) | |
| ax.plot([0] * len(bouts), [b[1] for b in bouts], "k.") | |
| ax.set_xticks([0], [""]) | |
| case "duration": | |
| norm_factor = (100 / self.duration) if norm_time else 1 | |
| ax.bar([0], [sum(b[1] for b in bouts) * norm_factor]) | |
| ax.plot([0] * len(bouts), np.cumsum([b[1] for b in bouts]) * norm_factor, "k.") | |
| ax.set_xticks([0], [""]) | |
| if x_label: | |
| ax.set_xlabel(x_label) | |
| if y_label: | |
| ax.set_ylabel(y_label) | |
| if show_plot: | |
| save_or_show(save_fig_root, save_fig_prefix) | |
| return bouts | |
class ExperimentParameters:
    """Configuration for one dataset: file locations, naming, analysis thresholds and plot grouping.

    Concrete datasets subclass this, provide every attribute declared below (as class attributes), and
    override the lookup methods. The methods defined here are placeholders that only raise
    ``NotImplementedError``.
    """

    # --- directories ---
    root: Path
    figure_root: Path
    data_root: Path
    video_root: Path
    image_root: Path
    # --- inventory columns and file naming ---
    metadata: list[str]
    filename_fmt: str
    experiment_triplet_name: tuple[str, str, str]
    # --- analysis parameters ---
    bouts_parameters: dict[str, dict[str, float]]
    thresholds: dict[str, float]
    frame_rate: float
    times_of_interest: dict[str, float]
    # --- video / figure rendering ---
    image_size: tuple[int, int]
    output_fig_h: int
    title_root_fmt: str
    subject_title_fmt: str
    video_filename_fmt: str
    chopped_pre_times: list[tuple[int, int]]
    # --- measure arguments forwarded to the exporters ---
    occupancy_meas_args: dict
    freezing_index_meas_args: dict
    freezing_speed_meas_args: dict
    # first element of cue, should be the label for the side of the cue
    cue_categoricals: tuple[str, str]
    freeze_categoricals: tuple[str, str]
    # --- grouping used by run_exports ---
    experiment_groups: dict[str, tuple[str, ...]]
    individual_plot_groups: tuple[str, ...]
    multi_plot_groups: list[dict[str, tuple[str, ...]]]

    def get_inventory_csv(self, batch: int | None = None) -> Path:
        """Return the CSV inventory listing the experiments of ``batch``."""
        raise NotImplementedError

    def get_tracked_data_root(self, batch: int | None = None) -> Path:
        """Return the directory holding the tracking output of ``batch``."""
        raise NotImplementedError

    def read_all_experiments(self) -> list[Experiment]:
        """Load every experiment of the dataset, ready for export."""
        raise NotImplementedError

    def get_cm_per_px(self, **selectors) -> float:
        """Return the cm-per-pixel scale for the experiment described by ``selectors`` (its metadata)."""
        raise NotImplementedError

    def get_cue_point(self, **selectors) -> tuple[float | str, float | str]:
        """Return the cue location (pixel coordinates or side names) for the experiment in ``selectors``."""
        raise NotImplementedError

    def get_box_size_cc(self, **selectors) -> tuple[float, float, float, float]:
        """Return the four box-geometry values passed to ``Experiment.set_box_metadata``."""
        raise NotImplementedError
class YidanParameters(ExperimentParameters):
    """Parameters for Yidan's 2025 teaball behavior dataset (batches listed in ``batch_numbers``)."""

    root = Path(r'H:\yidan_2025_behavior')
    figure_root = root / "results" / "figures"
    data_root = root / "results" / "data"
    video_root = root / "results" / "video"
    image_root = root / "results" / "video_images"
    metadata = ["date", "subject", "strain", "condition", "cell_label", "exposure", "sex"]
    filename_fmt = "{date}_{subject}_top_tracked.csv"
    experiment_triplet_name = "Pre Trial", "Trial", "Post Trial"
    # batches read by read_all_experiments; batch 3 takes its box geometry from per-experiment JSON files
    batch_numbers = (3, 5, 6, 7)
    # NOTE(review): the dart thresholds here (12.85, 0.1242) differ from ``thresholds`` below (14.32, 0.1391);
    # presumably intentional (bouts use downsample_factor=6) — confirm
    bouts_parameters = {
        "motion_index_freezing": {
            "threshold": 0.001278,  # %2.5
            "refractory_period": 0.25,
            "downsample_factor": 1,
            "min_bout_frames": 2,
        },
        "speed_freezing": {
            "threshold": 0.1180,  # %5
            "refractory_period": 0.25,
            "downsample_factor": 1,
            "min_bout_frames": 2,
        },
        "motion_index_dart": {
            "threshold": 0.1242,  # %90
            "refractory_period": 0.8,
            "downsample_factor": 6,
            "min_bout_frames": 2,
        },
        "speed_dart": {
            "threshold": 12.85,  # %95
            "refractory_period": 0.8,
            "downsample_factor": 6,
            "min_bout_frames": 2,
        }
    }
    thresholds = {
        "speed_freezing": 0.1180,  # %5
        "motion_index_freezing": 0.001278,  # %2.5
        "speed_dart": 14.32,  # %95
        "motion_index_dart": 0.1391,  # %90
    }
    frame_rate = 24
    times_of_interest = {
        "pre_offset": 3 * 60,
        "pre_duration": 2 * 60,
        "trial_offset": 0,
        "trial_duration": 2 * 60,
        "post_offset": 0,
        "post_duration": 2 * 60,
    }
    image_size = 960, 600
    # half the image height, rounded down to an even number
    output_fig_h = image_size[1] // 2
    output_fig_h = (output_fig_h // 2) * 2
    title_root_fmt = "{subject}-{exposure}-{date}"
    subject_title_fmt = "{subject}_exposure-{exposure}_{date}"
    video_filename_fmt = "{date}_{subject}_top.mp4"
    chopped_pre_times = [(i * 30, (i + 1) * 30) for i in range(10)]
    occupancy_meas_args = {"split_horizontally": True}
    freezing_index_meas_args = {"threshold": thresholds["motion_index_freezing"], "freeze_name": "Freezing"}
    freezing_speed_meas_args = {"threshold": thresholds["speed_freezing"], "freeze_name": "Freezing"}
    cue_categoricals = "Cue-side", "Far-side"
    freeze_categoricals = "Freezing", "Non-freezing"
    experiment_groups = {
        "condition": ("Blank", "TMT", "2MBA", "IAMM"),
        "strain": ("TRAP1",),
        "exposure": ("1", "2"),
        "sex": ("M", "F"),
    }
    individual_plot_groups = "condition", "strain", "exposure"
    multi_plot_groups = [
        {
            "individual_plot_per": ("exposure", "strain"),
            "groups_in_single_plot": ("condition", ),
        },
        {
            "individual_plot_per": ("condition", "strain"),
            "groups_in_single_plot": ("exposure", ),
        },
        {
            "individual_plot_per": ("condition", "strain"),
            "groups_in_single_plot": ("exposure", "sex"),
        },
    ]

    def get_inventory_csv(self, batch: int | None = None) -> Path:
        """Return the timestamp inventory CSV of ``batch``; this dataset has no default batch."""
        if batch is None:
            raise NotImplementedError
        return self.root / f"behavioral data timestamp - b{batch} timestamps.csv"

    def get_tracked_data_root(self, batch: int | None = None) -> Path:
        """Return the tracking directory of ``batch``; this dataset has no default batch."""
        if batch is None:
            raise NotImplementedError
        return self.root / "tracking" / f"batch{batch}"

    def read_all_experiments(self) -> list[Experiment]:
        """Read, calibrate, load and chop every experiment of every batch in ``batch_numbers``.

        Batch 3 reads its box geometry from the ``json`` sub-directory of its tracking root; the other batches
        use the fixed geometry of :meth:`get_box_size_cc`. All experiments are finally enlarged onto a common
        canvas.
        """
        all_experiments = []
        for num in self.batch_numbers:
            # reuse the path helpers so the per-batch layout is defined in exactly one place
            tracking_data_root = self.get_tracked_data_root(num)
            json_data_root = tracking_data_root / "json"
            experiments = Experiment.read_experiments_from_csv_inventory(
                self.get_inventory_csv(num), self.metadata, self.filename_fmt, self.title_root_fmt
            )
            all_experiments.extend(experiments)

            for experiment in experiments:
                if num == 3:
                    experiment.parse_box_metadata(
                        json_data_root, cm_per_px=self.get_cm_per_px(**experiment.metadata)
                    )
                else:
                    experiment.set_box_metadata(
                        *self.get_box_size_cc(**experiment.metadata), image_size=self.image_size,
                        cm_per_px=self.get_cm_per_px(**experiment.metadata),
                    )
                experiment.read_pos_track(
                    tracking_data_root, **self.times_of_interest, frame_rate=self.frame_rate,
                )
                experiment.chop_pos_track(hab_segments=self.chopped_pre_times)
                experiment.cue_point = self.get_cue_point(**experiment.metadata)

        # all experiments share the same scaling, even if slightly different sizes
        Experiment.enlarge_canvas(all_experiments)
        return all_experiments

    def get_cm_per_px(self, **selectors) -> float:
        """Return the single cm-per-pixel scale used for every experiment of this dataset."""
        return 20 / 855.1 * 2

    def get_cue_point(self, **selectors) -> tuple[float | str, float | str]:
        """Return the cue location, given as box sides rather than pixel coordinates."""
        return "left", "bottom"

    def get_box_size_cc(self, **selectors) -> tuple[float, float, float, float]:
        """Return the fixed box geometry used for batches other than 3."""
        return 200, 132, 784, 486
class CynthiaParameters(ExperimentParameters):
    """Parameters for Cynthia's cued experiments.

    Recordings may come from two chambers: the "train" stage uses the train chamber, and every other stage uses
    the test chamber unless ``test_is_train_chamber`` is set. Chamber choice drives the pixel scale, box geometry,
    cue location and which experiments are aligned onto a shared canvas.
    """

    root = Path(r'H:\cynthia')
    figure_root = root / "results" / "figures"
    data_root = root / "results" / "data"
    video_root = root / "results" / "video"
    image_root = root / "results" / "video_images"
    metadata = ["date", "subject", "stage", "condition", "cue", "sex"]
    filename_fmt = "{subject}_{stage}_{date}_tracked.csv"
    experiment_triplet_name = "Pre Trial", "Trial", "Post Trial"
    bouts_parameters = {
        "motion_index_freezing": {
            "threshold": 0.001278,  # %2.5
            "refractory_period": 0.25,
            "downsample_factor": 1,
            "min_bout_frames": 2,
        },
        "speed_freezing": {
            "threshold": 0.1180,  # %5
            "refractory_period": 0.25,
            "downsample_factor": 1,
            "min_bout_frames": 2,
        },
        "motion_index_dart": {
            "threshold": 0.1242,  # %90
            "refractory_period": 0.8,
            "downsample_factor": 6,
            "min_bout_frames": 2,
        },
        "speed_dart": {
            "threshold": 12.85,  # %95
            "refractory_period": 0.8,
            "downsample_factor": 6,
            "min_bout_frames": 2,
        }
    }
    thresholds = {
        "speed_freezing": 0.1180,  # %5
        "motion_index_freezing": 0.001278,  # %2.5
        "speed_dart": 14.32,  # %95
        "motion_index_dart": 0.1391,  # %90
    }
    frame_rate = 24
    times_of_interest = {
        "pre_duration": -28,
        "post_duration": 28
    }
    image_size = 640, 480
    # half the image height, rounded down to an even number
    output_fig_h = image_size[1] // 2
    output_fig_h = (output_fig_h // 2) * 2
    title_root_fmt = "{subject}-{stage}-{condition}-{cue}"
    subject_title_fmt = "{subject}_{stage}_{condition}_{cue}"
    video_filename_fmt = "{subject}_{stage}_{date}.mp4"
    chopped_pre_times = [(i * 30, (i + 1) * 30) for i in range(8)]
    occupancy_meas_args = {"split_horizontally": False}
    freezing_index_meas_args = {"threshold": thresholds["motion_index_freezing"], "freeze_name": "Freezing"}
    freezing_speed_meas_args = {"threshold": thresholds["speed_freezing"], "freeze_name": "Freezing"}
    cue_categoricals = "Cue-side", "Far-side"
    freeze_categoricals = "Freezing", "Non-freezing"
    experiment_groups = {
        "condition": ("cued", "control", "back"),
        "cue": ("odor",),
        "stage": ("test", "test_24hr"),
    }
    individual_plot_groups = "condition", "cue", "stage"
    multi_plot_groups = [
        {
            "individual_plot_per": ("condition", "cue"),
            "groups_in_single_plot": ("stage", ),
        },
        {
            "individual_plot_per": ("stage", "cue"),
            "groups_in_single_plot": ("condition", ),
        },
        {
            "individual_plot_per": ("cue", ),
            "groups_in_single_plot": ("condition", "stage"),
        },
    ]

    def __init__(self, batch_names: list[str], cue_in_corner: bool, test_is_train_chamber: bool):
        """
        :param batch_names: Batch names; each is both the inventory CSV stem and the tracking directory name
            under ``root``.
        :param cue_in_corner: Select the corner cue location instead of the other one.
        :param test_is_train_chamber: Treat every stage as recorded in the train chamber.
        """
        self.cue_in_corner = cue_in_corner
        self.test_is_train_chamber = test_is_train_chamber
        self.batch_names = batch_names

    def _in_train_chamber(self, selectors: dict) -> bool:
        """Return whether the experiment described by ``selectors`` (its metadata) used the train chamber."""
        # short-circuits so "stage" is only required when the chambers differ
        return self.test_is_train_chamber or selectors["stage"] == "train"

    def get_inventory_csv(self, batch: int | None = None) -> Path:
        """Return the inventory CSV of batch index ``batch`` (the first batch when ``None``)."""
        return self.root / f"{self.batch_names[0 if batch is None else batch]}.csv"

    def get_tracked_data_root(self, batch: int | None = None) -> Path:
        """Return the tracking directory of batch index ``batch`` (the first batch when ``None``)."""
        return self.root / self.batch_names[0 if batch is None else batch]

    def read_all_experiments(self) -> list[Experiment]:
        """Read, calibrate, load and chop the experiments of every batch, then align them per chamber."""
        all_experiments = []
        for num in range(len(self.batch_names)):
            experiments = Experiment.read_experiments_from_csv_inventory(
                self.get_inventory_csv(num), self.metadata, self.filename_fmt, self.title_root_fmt
            )
            all_experiments.extend(experiments)
            for experiment in experiments:
                experiment.set_box_metadata(
                    *self.get_box_size_cc(**experiment.metadata), image_size=self.image_size,
                    cm_per_px=self.get_cm_per_px(**experiment.metadata),
                )
                experiment.read_pos_track(
                    self.get_tracked_data_root(num), **self.times_of_interest, frame_rate=self.frame_rate,
                )
                experiment.chop_pos_track(hab_segments=self.chopped_pre_times)
                experiment.cue_point = self.get_cue_point(**experiment.metadata)

        if self.test_is_train_chamber:
            Experiment.enlarge_canvas(all_experiments)
        else:
            # align experiments that share the same scaling, even if slightly different sizes. The chamber is
            # decided exactly as in get_cm_per_px / get_box_size_cc: previously only stages literally named
            # "train" or "test" were aligned, so e.g. "test_24hr" recordings were never aligned at all
            train_chamber = [e for e in all_experiments if self._in_train_chamber(e.metadata)]
            test_chamber = [e for e in all_experiments if not self._in_train_chamber(e.metadata)]
            for group in (train_chamber, test_chamber):
                if group:
                    Experiment.enlarge_canvas(group)
        return all_experiments

    def get_cm_per_px(self, **selectors) -> float:
        """Return the cm-per-pixel scale of the chamber the experiment was recorded in."""
        if self._in_train_chamber(selectors):
            return 15 / 258.86
        return 15 / 200.01

    def get_cue_point(self, **selectors) -> tuple[float | str, float | str]:
        """Return the cue pixel location for the experiment's chamber and the configured cue placement."""
        if self._in_train_chamber(selectors):
            return (207, 52) if self.cue_in_corner else (345, 80)
        return (446, 429) if self.cue_in_corner else (330, 430)

    def get_box_size_cc(self, **selectors) -> tuple[float, float, float, float]:
        """Return the box geometry of the experiment's chamber."""
        # box sizes must be similar if averaging occupancy across them
        train = 74, 32, 593, 480
        test = 133, 0, 550, 466
        return train if self._in_train_chamber(selectors) else test
def run_exports(
        parameters: ExperimentParameters, export_csv: bool = True, export_subjects: bool = True,
        export_groups: bool = True, export_multi_groups: bool = True, cat_unit: Literal["times", "percent"] = "times",
):
    """Read every experiment described by ``parameters`` and write the requested exports.

    :param parameters: Dataset configuration to read and export.
    :param export_csv: Write the experiments CSV export into ``parameters.data_root``.
    :param export_subjects: Write one set of figures per experiment into ``parameters.figure_root``.
    :param export_groups: Write one set of figures per combination of values of
        ``parameters.individual_plot_groups``.
    :param export_multi_groups: For each entry of ``parameters.multi_plot_groups``, write one set of figures per
        combination of ``individual_plot_per`` values, overlaying every ``groups_in_single_plot`` combination.
    :param cat_unit: Unit of categorical measures, forwarded to the exporters.
    """
    # NOTE: changes class-level plotting state for the rest of the process
    SubjectPos._point_marker = {"marker": ".", "markersize": 1.5, "linestyle": ""}
    Experiment.triplet_name = parameters.experiment_triplet_name
    experiments = parameters.read_all_experiments()

    if export_csv:
        Experiment.export_experiments_csv(
            experiments, parameters.data_root, parameters.occupancy_meas_args, parameters.cue_categoricals,
            parameters.freezing_index_meas_args, parameters.freezing_speed_meas_args, parameters.freeze_categoricals,
            cat_unit, True, parameters.frame_rate, parameters.bouts_parameters,
        )

    if export_subjects:
        Experiment.export_single_subject_figures(
            experiments, parameters.figure_root, parameters.occupancy_meas_args, cat_unit,
            parameters.subject_title_fmt, parameters.frame_rate, parameters.bouts_parameters,
        )

    if export_groups:
        group_names = parameters.individual_plot_groups
        group_options = [parameters.experiment_groups[name] for name in group_names]
        for group_values in product(*group_options):
            selected = Experiment.filter(experiments, **dict(zip(group_names, group_values)))
            # str() each value (as the multi-group labels do) so non-string group values don't break the join
            label = "-".join(map(str, group_values))
            Experiment.export_multi_experiment_figures(
                selected, parameters.figure_root, parameters.occupancy_meas_args,
                cat_unit, label, parameters.cue_categoricals, parameters.freezing_index_meas_args,
                parameters.freezing_speed_meas_args, parameters.freeze_categoricals, parameters.frame_rate,
                parameters.bouts_parameters, None, "", "grouped",
            )

    if export_multi_groups:
        # resolve every split first, so a misconfigured group name fails before any figure is written
        items = []
        for split in parameters.multi_plot_groups:
            plot_names = split["individual_plot_per"]
            overlay_names = split["groups_in_single_plot"]
            plot_options = [parameters.experiment_groups[name] for name in plot_names]
            overlay_options = [parameters.experiment_groups[name] for name in overlay_names]
            # one filter per overlaid group combination; shared by every figure of this split
            overlay_filters = [dict(zip(overlay_names, values)) for values in product(*overlay_options)]
            for values in product(*plot_options):
                items.append((
                    dict(zip(plot_names, values)),
                    overlay_filters,
                    overlay_names,
                    "-".join(map(str, values)),
                    "-".join(overlay_names),
                ))

        for plot_groups, overlay_filters, overlay_names, label, sub_dir in items:
            selected = Experiment.filter(experiments, **plot_groups)
            # built from the names directly (not overlay_filters[0]) so an empty filter list cannot IndexError;
            # "{{{name}}}" presumably becomes "{<value>}" after a later str.format — confirm in the exporter
            bracket_group_names = ["{{{" + name + "}}}" for name in overlay_names]
            group_label = "$\\bf" + "$\n$\\bf".join(bracket_group_names) + "$\n\n"
            Experiment.export_multi_experiment_figures(
                selected, parameters.figure_root, parameters.occupancy_meas_args,
                cat_unit, label, parameters.cue_categoricals, parameters.freezing_index_meas_args,
                parameters.freezing_speed_meas_args, parameters.freeze_categoricals, parameters.frame_rate,
                parameters.bouts_parameters, overlay_filters, group_label, f"grouped_single_figure/by_{sub_dir}"
            )
def generate_video(
        parameters: ExperimentParameters, batch: int | None, time_window_padding: float, ffmpeg: str | Path,
        temp_images: bool, above_measure: Literal["motion_index", "speed"] | None,
        below_measure: Literal["motion_index", "speed"] | None,
        **selectors: str,
):
    """Render one experiment's recording merged with its measure plots via ``merge_video_with_images``.

    :param parameters: Dataset configuration.
    :param batch: Batch whose inventory CSV and tracking directory are used.
    :param time_window_padding: Forwarded to ``merge_video_with_images``.
    :param ffmpeg: Path of the ffmpeg executable, forwarded to ``merge_video_with_images``.
    :param temp_images: Forwarded to ``merge_video_with_images``.
    :param above_measure: Measure forwarded as the "above" plot (or ``None``).
    :param below_measure: Measure forwarded as the "below" plot (or ``None``).
    :param selectors: Inventory metadata values (e.g. ``date``, ``subject``) selecting the single experiment.
    """
    Experiment.triplet_name = parameters.experiment_triplet_name
    inventory_csv = parameters.get_inventory_csv(batch)
    track_root = parameters.get_tracked_data_root(batch)

    exp = Experiment.read_experiment_from_csv_inventory(
        inventory_csv, parameters.metadata, parameters.filename_fmt,
        parameters.title_root_fmt, **selectors,
    )
    # NOTE(review): unlike the read_all_experiments implementations, this always uses set_box_metadata and does
    # not chop the track or set cue_point — confirm merge_video_with_images does not need them
    box_geometry = parameters.get_box_size_cc(**exp.metadata)
    scale = parameters.get_cm_per_px(**exp.metadata)
    exp.set_box_metadata(*box_geometry, image_size=parameters.image_size, cm_per_px=scale)
    exp.read_pos_track(track_root, **parameters.times_of_interest, frame_rate=parameters.frame_rate)

    video_path = track_root / parameters.video_filename_fmt.format(**exp.metadata)
    exp.merge_video_with_images(
        parameters, parameters.frame_rate, parameters.image_size, parameters.output_fig_h,
        time_window_padding, video_path, ffmpeg, temp_images, above_measure, below_measure,
        parameters.bouts_parameters, parameters.thresholds,
    )
if __name__ == "__main__":
    # Dataset to process; swap in the commented YidanParameters() line for the teaball dataset.
    sample_parameters = CynthiaParameters(
        batch_names=["111025"],
        cue_in_corner=True,
        test_is_train_chamber=False,
    )
    # sample_parameters = YidanParameters()

    # Optional: print measure histograms/percentiles (presumably how the %-labelled thresholds were chosen).
    # Experiment.print_measure_hist(
    #     sample_parameters.read_all_experiments(), percentile=90, downsample_factor=1, hab_segment=(30, 270)
    # )

    # In this configuration only the per-subject figures are exported.
    run_exports(
        sample_parameters,
        export_csv=False,
        export_subjects=True,
        export_groups=False,
        export_multi_groups=False,
        # cat_unit="percent",
    )

    # Example video renders. NOTE(review): batches 5 and 6 only exist for YidanParameters (it reads batches
    # 3, 5, 6, 7); with the single-batch Cynthia configuration above they would index past ``batch_names``.
    # generate_video(
    #     sample_parameters, 5, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20241206", subject="1202",
    # )
    # generate_video(
    #     sample_parameters, 6, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20250423", subject="1282",
    # )
    # generate_video(
    #     sample_parameters, 6, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20250321", subject="1283",
    # )
    # generate_video(
    #     sample_parameters, 0, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20251024", subject="2831",
    # )

    # Fake-mouse calibration check. NOTE(review): the parameter classes expose ``get_cm_per_px(**selectors)``,
    # not a ``cm_per_px`` attribute, so ``sample_parameters.cm_per_px`` below needs updating before use.
    # pos_data = SubjectPos.parse_csv_track(
    #     Path(r"D:\code_data\yidan_2025\20251006_fake_mouse_top_tracked.csv"), frame_rate=sample_parameters.frame_rate,
    #     subject_name="mouse", cm_per_px=sample_parameters.cm_per_px,
    # )
    # exp = Experiment(0, 0, 0, 0, 0, 0, "", "", cm_per_px=sample_parameters.cm_per_px)
    # exp.pos_data = pos_data
    # Experiment.print_measure_hist(
    #     [exp], percentile=90, downsample_factor=1, hab_segment=(30, 270)
    # )
    # Experiment.plot_compare_motion_index_range(
    #     Path(r"D:\code_data\cynthia_sp_2025\1303_test_20250506_tracked.csv"),
    #     Path(r"D:\code_data\cynthia_sp_2025\fakemouse_test_20250604_tracked.csv"),
    #     "Real mouse",
    #     "Fake mouse",
    #     15,
    #     measure="speed"
    # )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment