Skip to content

Instantly share code, notes, and snippets.

@matham
Last active November 17, 2025 21:53
Show Gist options
  • Select an option

  • Save matham/2a499bbba251117287857da0aa6c5aeb to your computer and use it in GitHub Desktop.

Select an option

Save matham/2a499bbba251117287857da0aa6c5aeb to your computer and use it in GitHub Desktop.
Export results for teaball experiments - sniffing, occupancy etc
from dataclasses import dataclass, field
import pprint
from collections.abc import Sequence
from typing import Callable, Union, Literal, Any, Optional
from pathlib import Path
import csv
from scipy.signal import decimate
from tempfile import TemporaryDirectory
import subprocess
import math
import tqdm
import numpy as np
from functools import partial
from itertools import product
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage import gaussian_filter
import json
# Measure names accepted by Experiment.convert_to_categoricals().
CAT_MEASURE_TYPE = Literal["occupancy", "motion_index_freezing", "speed_freezing"]
# Measure names accepted by the bout export / bout plotting methods.
BOUT_MEASURE_TYPE = Literal["motion_index_freezing", "speed_freezing", "motion_index_dart", "speed_dart"]
def save_or_show(save_fig_root: None | Path = None, save_fig_prefix: str = "", width_inch: int = 8,
                 height_inch: int = 6):
    """Save the current matplotlib figure as a PNG, or show it interactively.

    Args:
        save_fig_root: Directory to save into (created if missing). When falsy
            the figure is shown on screen instead.
        save_fig_prefix: File name, without the ``.png`` suffix.
        width_inch: Figure width used when saving.
        height_inch: Figure height used when saving.
    """
    if not save_fig_root:
        # Interactive path: the figure size is left untouched.
        plt.tight_layout()
        plt.show()
        return

    save_fig_root.mkdir(parents=True, exist_ok=True)
    figure = plt.gcf()
    figure.set_size_inches(width_inch, height_inch)
    figure.tight_layout()
    target = save_fig_root / f"{save_fig_prefix}.png"
    figure.savefig(target, bbox_inches='tight', dpi=300)
    # Close the figure once it is on disk so figures do not pile up.
    plt.close()
def get_log_bins(data: np.ndarray, n_bins: int) -> np.ndarray:
    """Return ``n_bins`` log-spaced values spanning the decades that cover ``data``.

    The range runs from the power of ten at or below the smallest positive
    value to the power of ten at or above the largest one.

    Args:
        data: Samples to cover. Zero, negative and non-finite samples are
            ignored because they have no finite logarithm.
        n_bins: Number of values to generate (passed to ``np.logspace``).

    Returns:
        Array of ``n_bins`` strictly increasing values (for ``n_bins`` > 1).

    Raises:
        ValueError: If ``data`` holds no positive finite value.
    """
    values = np.asarray(data, dtype=float)
    positive = values[np.isfinite(values) & (values > 0)]
    if not positive.size:
        raise ValueError("Cannot compute log bins: data has no positive finite values")

    low = np.floor(np.log10(positive.min()))
    high = np.ceil(np.log10(positive.max()))
    if high <= low:
        # All samples sit on one exact power of ten; widen by a decade so the
        # bins do not collapse to zero width.
        high = low + 1
    return np.logspace(low, high, n_bins)
@dataclass(eq=False)
class Experiment:
    """One recording, split into a pre-trial, a trial and a post-trial period.

    ``eq=False`` stops the dataclass from generating ``__eq__``.
    """

    # Start/end of each period. The CSV inventory readers fill these with
    # seconds (converting "min:sec" entries).
    pre_start: float
    pre_end: float
    trial_start: float
    trial_end: float
    post_start: float
    post_end: float
    # Format string, filled with ``metadata``, giving the track file name relative to the data root.
    filename_fmt: str
    # Label used at the start of this experiment's figure titles.
    title_root: str
    # Identifying columns of the experiment; keys are column names, values the per-experiment ids.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Panel titles for the pre / trial / post periods, in that order.
    triplet_name: tuple[str, str, str] = "Habituation", "Trial", "Post Trial"
    # Scale from pixels to cm; None means values stay in pixels.
    cm_per_px: float | None = None
    # Full track, set by read_pos_track().
    pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    # Per-period slices of ``pos_data``, set by read_pos_track().
    pre_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    trial_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    post_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
    # Named sub-segments of the periods, filled by chop_pos_track().
    chopped_pos_data: dict[str, "SubjectPos"] = field(default_factory=dict, init=False, repr=False)
    # Box geometry in image pixels, set by parse_box_metadata() / set_box_metadata().
    box_center: tuple[int, int] = field(default=None, init=False)
    box_size: tuple[int, int] = field(default=None, init=False)
    image_size: tuple[int, int] = field(default=None, init=False)
    # how much we need to offset our box center so our box center aligns with the mean box center of all boxes
    image_offset: tuple[int, int] = field(default=None, init=False)
    # size of the largest image/canvas we need to fit all experiments, including their offsets
    enlarged_image_size: tuple[int, int] = field(default=None, init=False, repr=False)
    # size of the largest box of all the experiments
    largest_box_size: tuple[int, int] = field(default=None, init=False, repr=False)
    # Cue location: each coordinate is a number or a keyword understood by _descriptor_to_point().
    cue_point: tuple[float | str, float | str] = field(default=None, init=False, repr=False)
def get_header_id_columns(self) -> list[str]:
    """Return the metadata field names, in insertion order."""
    return [column for column in self.metadata]
def get_id_columns(self) -> list:
    """Return the metadata values, in the same order as ``get_header_id_columns``."""
    return [*self.metadata.values()]
def pos_path_filename(self, data_root: Path) -> Path:
    """Return the track file path: ``filename_fmt`` filled with the metadata, under ``data_root``."""
    relative_name = self.filename_fmt.format(**self.metadata)
    return data_root / relative_name
def box_metadata_filename(self, data_root: Path) -> Path:
    """Return the path of the box-annotation JSON for this experiment under ``data_root``.

    The file name is built from the ``date`` and ``subject`` metadata entries.
    """
    json_name = "{date}_{subject}_top_000000000.json".format(**self.metadata)
    return data_root / json_name
@classmethod
def read_experiment_from_csv_inventory(
cls, csv_filename: Path | str, metadata: list[str], filename_fmt: str,
title_root_fmt: str = "#{subject} ({date})", **selectors: str
) -> Optional["Experiment"]:
experiment = None
with open(csv_filename, "r") as fh:
reader = csv.reader(fh)
header = next(reader)
for row in reader:
metadata_values = {m: row[i] for i, m in enumerate(metadata)}
if not all(metadata_values.get(k) == v for k, v in selectors.items()):
continue
if experiment is not None:
raise ValueError(f"Found more than one row matching {selectors}")
times = []
for t in row[len(metadata):len(metadata) + 6]:
if ":" in t:
min, sec = map(int, t.split(":"))
times.append(min * 60 + sec)
else:
times.append(int(t))
experiment = cls(
metadata=metadata_values, filename_fmt=filename_fmt,
pre_start=times[0], pre_end=times[1],
trial_start=times[2], trial_end=times[3],
post_start=times[4], post_end=times[5], title_root=title_root_fmt.format(**metadata_values),
)
if experiment is None:
raise ValueError(f"Didn't find any row matching {selectors}")
return experiment
@classmethod
def read_experiments_from_csv_inventory(
        cls, filename: Path, metadata: list[str], filename_fmt: str,
        title_root_fmt: str = "#{subject} ({date})",
) -> list["Experiment"]:
    """Read every experiment listed in a CSV inventory.

    The first row is treated as a header and skipped. In every other row the
    first ``len(metadata)`` cells are the metadata values, and the next six are
    the pre/trial/post start and end times, given as whole seconds or
    ``min:sec``. Blank rows are ignored.

    Args:
        filename: Path of the inventory file.
        metadata: Names of the leading metadata columns, in file order.
        filename_fmt: Stored on each experiment as ``filename_fmt``.
        title_root_fmt: Format string, filled with the metadata values, used
            as each experiment's ``title_root``.

    Returns:
        One experiment per data row, in file order.
    """
    def to_seconds(text: str) -> int:
        # "M:SS" -> seconds; anything else must already be whole seconds.
        if ":" in text:
            minutes, seconds = map(int, text.split(":"))
            return minutes * 60 + seconds
        return int(text)

    n_meta = len(metadata)
    experiments = []
    # newline="" is what the csv module expects so it can handle line endings itself.
    with open(filename, "r", newline="") as fh:
        reader = csv.reader(fh)
        next(reader, None)  # skip the header row (tolerates an empty file)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines, e.g. a trailing newline.
                continue
            metadata_values = {name: row[i] for i, name in enumerate(metadata)}
            times = [to_seconds(cell) for cell in row[n_meta:n_meta + 6]]
            experiments.append(cls(
                metadata=metadata_values, filename_fmt=filename_fmt,
                pre_start=times[0], pre_end=times[1],
                trial_start=times[2], trial_end=times[3],
                post_start=times[4], post_end=times[5], title_root=title_root_fmt.format(**metadata_values),
            ))
    return experiments
def parse_box_metadata(
self, data_root: Path, box_names: tuple[str] = ("farside", "nearside"),
cm_per_px: float | None = None,
):
filename = self.box_metadata_filename(data_root)
with open(filename, "r") as fh:
data = json.load(fh)
shape = None
for s in data["shapes"]:
for name in box_names:
if s["label"] == name:
shape = s
if shape is None:
raise ValueError(f"Cannot find {box_names} in the json file. {filename}")
points = np.array(shape["points"])
min_x, min_y = np.min(points, axis=0)
max_x, max_y = np.max(points, axis=0)
w = max_x - min_x
h = max_y - min_y
self.box_center = int(min_x + w / 2), int(min_y + h / 2)
self.box_size = int(w), int(h)
self.image_size = int(data["imageWidth"]), int(data["imageHeight"])
self.cm_per_px = cm_per_px
def set_box_metadata(
self, box_left: int, box_top: int, box_right: int, box_bottom: int, image_size: tuple[int, int],
cm_per_px: float | None = None,
):
self.box_center = int((box_right + box_left) / 2), int((box_bottom + box_top) / 2)
self.box_size = box_right - box_left, box_bottom - box_top
self.image_size = image_size
self.image_offset = 0, 0
self.enlarged_image_size = self.image_size
self.largest_box_size = self.box_size
self.cm_per_px = cm_per_px
@classmethod
def enlarge_canvas(cls, experiments: list["Experiment"]) -> tuple[int, int]:
    """Compute one canvas that fits all experiments with their box centers aligned.

    Each image is first centered within the largest image, then shifted so its
    box center lands on the (floored) mean box center. Every experiment gets
    its ``image_offset``, the shared ``enlarged_image_size`` and the shared
    ``largest_box_size``.

    Returns:
        The canvas size, identical to each experiment's ``enlarged_image_size``.
    """
    centers = np.array([exp.box_center for exp in experiments])
    boxes = np.array([exp.box_size for exp in experiments])
    images = np.array([exp.image_size for exp in experiments])

    largest_image = np.max(images, axis=0)
    half_images = np.floor(images / 2)
    # Shift that centers each image inside the largest image.
    centered_halves = np.floor(half_images + (largest_image[None, :] - images) / 2)
    centering_shift = centered_halves - half_images

    shifted_centers = centers + centering_shift
    common_center = np.floor(np.mean(shifted_centers, axis=0))
    # Extra shift that moves each box center onto the common center.
    alignment_shift = common_center - shifted_centers
    offsets = centering_shift + alignment_shift

    canvas = largest_image + np.max(offsets, axis=0) - np.min(offsets, axis=0)
    canvas_size = tuple(int(v) for v in canvas)
    largest_box = np.max(boxes, axis=0)

    for exp, offset in zip(experiments, offsets):
        exp.image_offset = tuple(int(v) for v in offset)
        exp.enlarged_image_size = tuple(int(v) for v in canvas)
        exp.largest_box_size = tuple(int(v) for v in largest_box)
    return canvas_size
def position_to_side(
        self, x: int, y: int, split_horizontally: bool = True,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
):
    """Name the half of the box that ``(x, y)`` falls in, relative to the cue.

    The box is split at its center along x (``split_horizontally``) or along
    y. ``categoricals[0]`` is returned when the position is in the same half
    as the cue point, ``categoricals[1]`` otherwise.
    """
    cx, cy = self.box_center
    cue_x, cue_y = self._descriptor_to_point(self.cue_point)

    if split_horizontally:
        pos, center, cue = x, cx, cue_x
    else:
        pos, center, cue = y, cy, cue_y

    # The "low" half is strictly below the center; the center itself belongs
    # to the "high" half.
    if cue < center:
        in_cue_half = pos < center
    else:
        in_cue_half = pos >= center
    return categoricals[0 if in_cue_half else 1]
def convert_to_categoricals(
        self, data_group: Union["SubjectPos", str], measure: CAT_MEASURE_TYPE,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
        measure_options: dict | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Turn a track into a per-sample category index for the given measure.

    Args:
        data_group: The track, or the name of the attribute of ``self`` that
            holds it (e.g. ``"trial_pos_data"``).
        measure: ``"occupancy"`` classifies each position with
            ``position_to_side``; the two ``*_freezing`` measures classify each
            sample as freezing / not freezing by thresholding the motion index
            or the speed.
        categoricals: Category names; the returned indices refer to this tuple.
        measure_options: For ``"occupancy"``: optional ``"split_horizontally"``
            (default True). For the freezing measures: required
            ``"freeze_name"`` (the entry of ``categoricals`` that means
            freezing) and ``"threshold"`` (values <= threshold count as
            freezing).

    Returns:
        ``(times, time_diff, categoricals_index)``.

    Raises:
        ValueError: If ``measure`` is not recognized.
    """
    if measure_options is None:
        measure_options = {}
    if isinstance(data_group, str):
        data_group = getattr(self, data_group)
    if measure == "occupancy":
        f = partial(
            self.position_to_side,
            split_horizontally=measure_options.get("split_horizontally", True), categoricals=categoricals,
        )
        times, time_diff, categoricals_index = data_group.transform_to_categorical_pos(f, categoricals)
    elif "freezing" in measure:
        freeze_i = categoricals.index(measure_options["freeze_name"])
        # The other category; this assumes exactly two categories.
        non_freeze_i = 1 if freeze_i == 0 else 0
        if measure == "motion_index_freezing":
            cm_per_px = data_group.cm_per_px or 1
            # Drop the last sample so the index lines up with the time differences below.
            index = data_group.motion_index[:-1] * cm_per_px
            # Negative motion index values are excluded — presumably invalid-sample markers; TODO confirm.
            valid = index >= 0
            times = data_group.times
            time_diff = times[1:] - times[:-1]
            assert np.all(time_diff >= 0)
            index = index[valid]
            time_diff = time_diff[valid]
            times = times[:-1][valid]
        else:
            assert measure == "speed_freezing"
            times, index, _, time_diff = data_group.calculate_speed()
        # np.empty gives a float array; every element is assigned below.
        categoricals_index = np.empty(len(times))
        freezing = index <= measure_options["threshold"]
        categoricals_index[freezing] = freeze_i
        categoricals_index[np.logical_not(freezing)] = non_freeze_i
    else:
        raise ValueError(f"Unknown measure {measure}")
    return times, time_diff, categoricals_index
def position_to_grid_index(
        self, x, y, grid_width: int, grid_height: int, y_at_top: bool = False
) -> tuple[int, int]:
    """Map an image position to an integer cell of the shared canvas grid.

    The position is scaled about the box center so this box matches the
    largest box, shifted by ``image_offset``, rounded, and clamped into
    ``[0, grid_width - 1] x [0, grid_height - 1]``.

    Args:
        x, y: Position in this experiment's image coordinates.
        grid_width, grid_height: Size of the target grid.
        y_at_top: If True, flip the resulting row index vertically.
    """
    offset_x, offset_y = self.image_offset
    cx, cy = self.box_center
    bw, bh = self.box_size
    largest_bw, largest_bh = self.largest_box_size
    scale_x = largest_bw / bw
    scale_y = largest_bh / bh

    def clamp(value, upper):
        # Round to the nearest cell, then keep the result inside the grid.
        return int(min(max(round(value), 0), upper - 1))

    col = clamp((x - cx) * scale_x + cx + offset_x, grid_width)
    row = clamp((y - cy) * scale_y + cy + offset_y, grid_height)
    if y_at_top:
        row = grid_height - 1 - row
    return col, row
def read_pos_track(
        self, data_root: Path, pre_offset: float = 0, pre_duration: float | None = None, trial_offset: float = 0,
        trial_duration: float | None = None, post_offset: float = 0, post_duration: float | None = None,
        frame_rate: float = 30.
) -> None:
    """Load the track file and slice it into the pre, trial and post periods.

    Sets ``pos_data`` (full track) and ``pre_pos_data``, ``trial_pos_data``,
    ``post_pos_data``.

    For each period, ``*_offset`` is added to the period start. A positive
    ``*_duration`` keeps at most that long counting from the (offset) start; a
    negative one keeps at most that long counting back from the period end.
    ``None`` or 0 keeps everything up to the period end.
    """
    def window(period_start, period_end, offset, duration):
        # Resolve offset/duration into the (start, end) range to extract.
        begin = period_start + offset
        finish = period_end
        if duration and duration > 0:
            finish = min(finish, begin + duration)
        elif duration and duration < 0:
            begin = max(begin, finish + duration)
        return begin, finish

    self.pos_data = SubjectPos.parse_csv_track(
        self.pos_path_filename(data_root), frame_rate=frame_rate, cm_per_px=self.cm_per_px
    )

    pre_range = window(self.pre_start, self.pre_end, pre_offset, pre_duration)
    self.pre_pos_data = self.pos_data.extract_range(*pre_range)

    trial_range = window(self.trial_start, self.trial_end, trial_offset, trial_duration)
    self.trial_pos_data = self.pos_data.extract_range(*trial_range)

    post_range = window(self.post_start, self.post_end, post_offset, post_duration)
    self.post_pos_data = self.pos_data.extract_range(*post_range)
def chop_pos_track(
        self, hab_segments: list[tuple[float, float]] = (), trial_segments: list[tuple[float, float]] = (),
        post_segments: list[tuple[float, float]] = (),
) -> None:
    """Cut the periods into sub-segments and store them in ``chopped_pos_data``.

    Each segment is ``(start, end)`` relative to the start of its period; the
    end is clipped to the period end. Keys are ``"Pre"``, ``"Trial"`` and
    ``"Post"``, or ``"<name>_<int(start)>"`` when a period has more than one
    segment. Any previously chopped data is discarded.
    """
    chopped = self.chopped_pos_data = {}
    period_table = (
        ("Pre", hab_segments, self.pre_start, self.pre_end),
        ("Trial", trial_segments, self.trial_start, self.trial_end),
        ("Post", post_segments, self.post_start, self.post_end),
    )
    for name, segments, period_start, period_end in period_table:
        for start, end in segments:
            piece = self.pos_data.extract_range(period_start + start, min(period_start + end, period_end))
            if len(segments) > 1:
                chopped[f"{name}_{int(start)}"] = piece
            else:
                chopped[name] = piece
def extract_subject_period(
        self, hab_segment: tuple[float, float] = (), trial_segment: tuple[float, float] = (),
        post_segment: tuple[float, float] = (),
) -> Optional["SubjectPos"]:
    """Extract one ``(start, end)`` segment, relative to the start of its period.

    Only the first non-empty argument is used, checked in the order
    habituation, trial, post. The segment end is clipped to the period end.

    Returns:
        The extracted track, or None when no segment was given.
    """
    candidates = (
        (hab_segment, self.pre_start, self.pre_end),
        (trial_segment, self.trial_start, self.trial_end),
        (post_segment, self.post_start, self.post_end),
    )
    for segment, period_start, period_end in candidates:
        if segment:
            start, end = segment
            return self.pos_data.extract_range(period_start + start, min(period_start + end, period_end))
    return None
@classmethod
def _get_data_items(cls, obj: Union["Experiment", None] = None) -> list[tuple[Union["SubjectPos", str], str]]:
    """Pair the three period tracks with their display titles.

    Args:
        obj: If given, the pairs hold that experiment's actual period data;
            otherwise they hold the attribute names, for later ``getattr``.

    Returns:
        ``[(data_or_attribute_name, title), ...]`` in pre, trial, post order,
        with titles taken from ``triplet_name``.
    """
    pre_title, trial_title, post_title = cls.triplet_name
    titles = (pre_title, trial_title, post_title)
    attribute_names = ("pre_pos_data", "trial_pos_data", "post_pos_data")
    if obj is None:
        return list(zip(attribute_names, titles))
    return [(getattr(obj, name), title) for name, title in zip(attribute_names, titles)]
@classmethod
def _iter_periods_and_groups(cls, experiments, periods, axs, filter_args: list[dict] | None):
n_groups = 1 if filter_args is None else len(filter_args)
it = iter(axs.flatten())
for i, filter_group in enumerate(filter_args or [None, ]):
if n_groups > 1:
experiments_ = cls.filter(experiments, **filter_group)
else:
experiments_ = experiments
for j, (data_name, t) in enumerate(periods):
ax = next(it)
yield experiments_, i, filter_group, j, data_name, t, ax
def plot_occupancy(
        self,
        gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
        scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
        show_point: bool = True, point: tuple[float | str, float | str] | None = None,
):
    """Plot this experiment's occupancy density, one panel per period.

    Args:
        gaussian_sigma, intensity_limit, frame_normalize, scale_to_one:
            Forwarded unchanged to each period's ``plot_occupancy``.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        show_point: If True, mark ``point`` on every panel.
        point: Point to mark, as accepted by ``_descriptor_to_point``;
            defaults to ``self.cue_point``.
    """
    if point is None:
        point = self.cue_point
    point = self._descriptor_to_point(point)
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
    unit = "cm" if self.cm_per_px else "pixels"
    for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
        data.plot_occupancy(
            *self.enlarged_image_size, pos_to_index=self.position_to_grid_index, fig=fig, ax=ax,
            gaussian_sigma=gaussian_sigma,
            intensity_limit=intensity_limit, frame_normalize=frame_normalize, scale_to_one=scale_to_one,
            # color bar and y label only on the first panel
            color_bar=not i,
            x_label=f"X ({unit})",
            y_label=f"Y ({unit})" if not i else "",
            title="",
        )
        ax.set_title(f"$\\bf{{{title}}}$")
        if show_point:
            cm_per_px = data.cm_per_px or 1
            # Convert the point to grid cells, then to the plotted unit.
            x, y = self.position_to_grid_index(*point, *self.enlarged_image_size)
            ax.plot([x * cm_per_px], [y * cm_per_px], "*r")
    label = self.title_root.format(**self.metadata)
    fig.suptitle(f"{label} occupancy density")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_occupancy(
        cls, experiments: list["Experiment"],
        gaussian_sigma: float = 0, intensity_limit: float = 0, title: str = "Subjects occupancy density",
        frame_normalize: bool = True, experiment_normalize: bool = True, scale_to_one: bool = True,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
        show_point: bool = True, point: tuple[float | str, float | str] | None = None,
):
    """Plot the occupancy accumulated over several experiments.

    The figure has one row per filter group and one column per period. All
    experiments in a group must share ``enlarged_image_size`` and
    ``cm_per_px``.

    Args:
        experiments: Experiments to accumulate.
        gaussian_sigma, intensity_limit, scale_to_one: Forwarded to
            ``SubjectPos.plot_occupancy_data``.
        title: Figure title.
        frame_normalize: Forwarded to each track's ``calculate_occupancy``.
        experiment_normalize: Divide the accumulated occupancy by the number
            of experiments in the group.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        filter_args: One filter dict per group (row), or None for one row.
        group_label: Format string, filled with the group's filter dict, that
            is prefixed to the y label when there is more than one group.
        show_point: If True, mark ``point`` on every panel.
        point: Point to mark; defaults to the first experiment's ``cue_point``.

    Raises:
        ValueError: If a group's experiments have different canvas sizes.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        if not experiments_:
            continue
        unit = "cm" if experiments_[0].cm_per_px else "pixels"
        cm_per_px = experiments_[0].cm_per_px
        assert len({exp.cm_per_px for exp in experiments_}) == 1
        grid_size = {e.enlarged_image_size for e in experiments_}
        if len(grid_size) != 1:
            raise ValueError("Trying to plot occupancy of experiments with different sizes")
        grid_size = list(grid_size)[0]
        # Shared accumulator passed to every experiment's calculate_occupancy.
        occupancy = np.zeros(grid_size)
        for experiment in experiments_:
            getattr(experiment, data_name).calculate_occupancy(
                occupancy, experiment.position_to_grid_index, frame_normalize,
            )
        # NOTE(review): nesting reconstructed from a de-indented paste; the
        # normalization is assumed to run once, after the accumulation loop.
        if experiment_normalize:
            occupancy /= len(experiments_)
        group = ""
        if n_groups > 1:
            group = group_label.format(**filter_group)
        SubjectPos.plot_occupancy_data(
            occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one,
            # color bar only on the top-left panel
            color_bar=not i and not j,
            title="",
            x_label=f"X ({unit})" if i == n_groups - 1 else "",
            y_label=f"{group}Y ({unit})" if not j else "",
            cm_per_px=cm_per_px,
        )
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
        if show_point:
            # The first experiment of the group stands in for the whole group.
            exp_point = experiments_[0].cue_point if point is None else point
            exp_point = experiments_[0]._descriptor_to_point(exp_point)
            x, y = experiments_[0].position_to_grid_index(*exp_point, *grid_size)
            ax.plot([x * (cm_per_px or 1)], [y * (cm_per_px or 1)], "*r")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def export_multi_experiment_motion(
        cls, experiments: list["Experiment"], filename: Path,
        measure: Literal["motion_index", "speed"] = "motion_index", use_chopped_data: bool = False,
) -> None:
    """Write the mean motion of every experiment and stage to a CSV file.

    Each row holds the experiment's id columns, the stage name and the mean of
    the non-negative motion samples (0 when there are none).

    Args:
        experiments: Experiments to export; the first one defines the header
            id columns and the unit shown in the header.
        filename: Output CSV path; parent directories are created.
        measure: ``"motion_index"`` uses each track's motion index scaled by
            ``cm_per_px``; ``"speed"`` uses ``calculate_speed()``.
        use_chopped_data: Export the ``chopped_pos_data`` segments instead of
            the Pre/Trial/Post periods.

    Raises:
        ValueError: If ``measure`` is not recognized. It is checked before
            the file is created.
    """
    if measure not in ("motion_index", "speed"):
        # Without this check `motion` would be unbound for the first stage
        # and silently stale for later ones.
        raise ValueError(f"Unknown measure {measure}")

    unit = "px" if experiments[0].cm_per_px is None else "cm"
    header = experiments[0].get_header_id_columns() + ["Stage", f"Motion mean ({unit}/s)"]
    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)
        for experiment in experiments:
            cm_per_px = experiment.cm_per_px or 1
            if use_chopped_data:
                stages = list(experiment.chopped_pos_data.keys())
                stages_data = [experiment.chopped_pos_data[stage] for stage in stages]
            else:
                stages = "Pre", "Trial", "Post"
                stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
            id_columns = experiment.get_id_columns()
            for stage, data in zip(stages, stages_data):
                if measure == "motion_index":
                    motion = data.motion_index * cm_per_px
                else:
                    _, motion, _, _ = data.calculate_speed()
                # Negative samples are left out of the mean.
                motion = motion[motion >= 0]
                mean_motion = np.mean(motion) if len(motion) else 0
                line = id_columns + [stage, mean_motion]
                writer.writerow(map(str, line))
@classmethod
def export_bouts(
        cls, experiments: list["Experiment"], filename: Path, measure: BOUT_MEASURE_TYPE,
        threshold: float, frame_rate: float, refractory_period: float = 0,
        downsample_factor: int = 1, use_chopped_data: bool = False, min_bout_frames: int = 0,
        unit: Literal["percent", "times"] = "times",
) -> None:
    """Write per-stage bout statistics of every experiment to a CSV file.

    Each row holds the experiment's id columns, the stage, the number of
    bouts, the mean bout duration and the total bout duration. With
    ``unit="percent"`` the count is divided by the stage duration and the
    total is given as a percentage of it.

    Args:
        experiments: Experiments to export; the first defines the id header.
        filename: Output CSV path; parent directories are created.
        measure, threshold, frame_rate, refractory_period, downsample_factor,
            min_bout_frames: Forwarded to each track's
            ``calculate_measure_bout``.
        use_chopped_data: Export the ``chopped_pos_data`` segments instead of
            the Pre/Trial/Post periods.
        unit: ``"times"`` for absolute values, ``"percent"`` for normalized.
    """
    absolute = unit == "times"
    count_title = "Number of bouts" + ("" if absolute else " / s")
    total_title = "Total duration " + ("(s)" if absolute else "(percent)")
    header = experiments[0].get_header_id_columns() + [
        "Stage", count_title, "Mean duration (s)", total_title,
    ]

    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)
        for experiment in experiments:
            if use_chopped_data:
                stage_names = list(experiment.chopped_pos_data.keys())
                stage_tracks = [experiment.chopped_pos_data[name] for name in stage_names]
            else:
                stage_names = "Pre", "Trial", "Post"
                stage_tracks = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
            id_columns = experiment.get_id_columns()

            for stage_name, track in zip(stage_names, stage_tracks):
                bouts = track.calculate_measure_bout(
                    measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames,
                )
                # The second element of each bout is summed as its duration.
                total = sum(bout[1] for bout in bouts)
                count = len(bouts)
                mean = (total / len(bouts)) if bouts else 0
                if unit == "percent":
                    total = total / track.duration * 100
                    count /= track.duration
                row = id_columns + [stage_name, count, mean, total]
                writer.writerow([str(cell) for cell in row])
def plot_motion_index(
        self, y_limit: float | None = None, save_fig_root: None | Path = None, save_fig_prefix: str = ""
):
    """Plot this experiment's motion index, one panel per period.

    Args:
        y_limit: If given, every panel's y axis is set to ``(0, y_limit)``.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    panels = zip(self._get_data_items(self), axs.flatten())
    for column, ((period_data, period_title), axis) in enumerate(panels):
        # Only the first panel keeps the default y label.
        label_override = {"y_label": ""} if column else {}
        period_data.plot_motion_index(fig, axis, **label_override)
        if y_limit is not None:
            axis.set_ylim(0, y_limit)
        axis.set_title(f"$\\bf{{{period_title}}}$")
    heading = self.title_root.format(**self.metadata)
    fig.suptitle(f"{heading} motion index")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion(
        cls, experiments: list["Experiment"], y_limit: float | None = None, title: str = "Subjects motion index",
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
        measure: Literal["motion_index", "speed"] = "motion_index",
):
    """Overlay the motion traces of several experiments on one figure.

    The figure has one row per filter group and one column per period; every
    experiment of a group is drawn on the same axis.

    Args:
        experiments: Experiments to plot.
        y_limit: If given, every panel's y axis is set to ``(0, y_limit)``.
        title: Figure title.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        filter_args: One filter dict per group (row), or None for one row.
        group_label: Format string, filled with the group's filter dict, that
            is prefixed to the y label when there is more than one group.
        measure: Selects the track method ``plot_<measure>`` and the y label.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        for experiment in experiments_:
            group = ""
            if n_groups > 1:
                group = group_label.format(**filter_group)
            # x label only on the bottom row, y label only in the first column.
            kwargs = {"y_label": ""}
            if i != n_groups - 1:
                kwargs["x_label"] = ""
            if not j:
                label = "Motion index" if measure == "motion_index" else "Speed"
                if experiment.cm_per_px:
                    label += " (cm/s)"
                else:
                    label += " (px/s)"
                kwargs["y_label"] = f"{group}{label}"
            data = getattr(experiment, data_name)
            getattr(data, f"plot_{measure}")(fig, ax, **kwargs)
        # NOTE(review): nesting reconstructed from a de-indented paste; the
        # y limit and title are assumed to be applied once per axis.
        if y_limit is not None:
            ax.set_ylim(0, y_limit)
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_motion_index_histogram(
        self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "",
        hist_range: tuple[float, float] = (0, 3), log_x: bool = False,
):
    """Plot a density histogram of the motion index, one panel per period.

    Args:
        n_bins: Number of bins (linear), or number of log-spaced bin edges
            (``log_x``).
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        hist_range: x limits of the linear histogram.
        log_x: Use log-spaced bins and a log x axis. Only positive samples
            are plotted and the panel title reports the share of the others.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
    cm_per_px = self.cm_per_px or 1
    unit = "px" if self.cm_per_px is None else "cm"
    for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
        data = data.motion_index * cm_per_px
        if log_x:
            # NOTE(review): `total` also counts negative samples, so they are
            # reported as "zeros"; the multi-experiment version drops them
            # first — confirm which is intended.
            total = len(data)
            data = data[data > 0]
            if not len(data):
                continue
            non_zero = len(data)
            percent_zero = (total - non_zero) / total * 100
            ax.hist(data, bins=get_log_bins(data, n_bins), density=True)
            ax.set_xscale("log")
            ax.set_title(f"$\\bf{{{title}}}$ zeros = {percent_zero:0.2f}%")
        else:
            data = data[data >= 0]
            if not len(data):
                continue
            ax.hist(data, bins=n_bins, density=True)
            ax.set_title(f"$\\bf{{{title}}}$")
            # NOTE(review): nesting reconstructed from a de-indented paste;
            # the x limit is assumed to apply to the linear histogram only.
            ax.set_xlim(*hist_range)
        ax.set_xlabel(f"Motion index ({unit}/s)")
        if not i:
            ax.set_ylabel("Density")
    label = self.title_root.format(**self.metadata)
    fig.suptitle(f"{label} motion index density")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion_index_histogram(
        cls, experiments: list["Experiment"], n_bins=100, title: str = "Subjects motion index",
        save_fig_root: None | Path = None, save_fig_prefix: str = "", hist_range: tuple[float, float] = (0, 3),
        filter_args: list[dict] | None = None, group_label: str = "", log_x: bool = False,
):
    """Plot density histograms of the motion index pooled over experiments.

    The figure has one row per filter group and one column per period. The
    non-negative samples of all experiments in a group are concatenated
    before the histogram is taken. Nothing is done if ``experiments`` is
    empty.

    Args:
        experiments: Experiments to pool.
        n_bins: Number of bins (linear), or number of log-spaced bin edges
            (``log_x``).
        title: Figure title.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        hist_range: x limits of the linear histogram.
        filter_args: One filter dict per group (row), or None for one row.
        group_label: Format string, filled with the group's filter dict, that
            is prefixed to the y label when there is more than one group.
        log_x: Use log-spaced bins and a log x axis; the panel title then
            reports the percentage of zero samples.
    """
    if not experiments:
        return
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
    unit = "px" if experiments[0].cm_per_px is None else "cm"
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        if not experiments_:
            continue
        items = []
        for experiment in experiments_:
            cm_per_px = experiment.cm_per_px or 1
            data = getattr(experiment, data_name).motion_index * cm_per_px
            items.append(data[data >= 0])
        if log_x:
            data = np.concatenate(items)
            total = len(data)
            data = data[data > 0]
            if not len(data):
                continue
            non_zero = len(data)
            percent_zero = (total - non_zero) / total * 100
            ax.hist(data, bins=get_log_bins(data, n_bins), density=True)
            ax.set_xscale("log")
            # The period name is only shown on the first row.
            if i:
                ax.set_title(f"zeros = {percent_zero:0.2f}%")
            else:
                ax.set_title(f"$\\bf{{{t}}}$ zeros = {percent_zero:0.2f}%")
        else:
            # NOTE(review): nesting of this branch reconstructed from a
            # de-indented paste — confirm against the original file.
            if items:
                data = np.concatenate(items)
                if not len(data):
                    continue
                ax.hist(data, bins=n_bins, density=True)
            if not i:
                ax.set_title(f"$\\bf{{{t}}}$")
            ax.set_xlim(*hist_range)
        group = ""
        if n_groups > 1:
            group = group_label.format(**filter_group)
        if i == n_groups - 1:
            ax.set_xlabel(f"Motion index ({unit}/s)")
        if not j:
            ax.set_ylabel(f"{group}Density")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_speed(
        self, y_limit: float | None = None, save_fig_root: None | Path = None,
        save_fig_prefix: str = ""
):
    """Plot this experiment's speed, one panel per period.

    Args:
        y_limit: If given, every panel's y axis is set to ``(0, y_limit)``.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    panels = zip(self._get_data_items(self), axs.flatten())
    for column, ((period_data, period_title), axis) in enumerate(panels):
        # Only the first panel keeps the default y label.
        label_override = {"y_label": ""} if column else {}
        period_data.plot_speed(fig, axis, **label_override)
        if y_limit is not None:
            axis.set_ylim(0, y_limit)
        axis.set_title(f"$\\bf{{{period_title}}}$")
    heading = self.title_root.format(**self.metadata)
    fig.suptitle(f"{heading} speed")
    save_or_show(save_fig_root, save_fig_prefix)
def plot_speed_histogram(
        self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = "",
        hist_range: tuple[float, float] = (0, 200), log_x: bool = False,
):
    """Plot a density histogram of the speed, one panel per period.

    Args:
        n_bins: Number of bins (linear), or number of log-spaced bin edges
            (``log_x``).
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        hist_range: x limits of the linear histogram.
        log_x: Use log-spaced bins and a log x axis. Only positive samples
            are plotted and the panel title reports the share of the others.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
    for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
        # Second element of calculate_speed()'s result is the speed array.
        _, data, _, _ = data.calculate_speed()
        if log_x:
            # NOTE(review): `total` also counts any negative samples, so they
            # are reported as "zeros"; the multi-experiment version drops
            # them first — confirm which is intended.
            total = len(data)
            data = data[data > 0]
            if not len(data):
                continue
            non_zero = len(data)
            percent_zero = (total - non_zero) / total * 100
            ax.hist(data, bins=get_log_bins(data, n_bins), density=True)
            ax.set_xscale("log")
            ax.set_title(f"$\\bf{{{title}}}$ zeros = {percent_zero:0.2f}%")
        else:
            data = data[data >= 0]
            if not len(data):
                continue
            ax.hist(data, bins=n_bins, density=True)
            ax.set_title(f"$\\bf{{{title}}}$")
            # NOTE(review): nesting reconstructed from a de-indented paste;
            # the x limit is assumed to apply to the linear histogram only.
            ax.set_xlim(*hist_range)
        ax.set_xlabel("Speed (cm/s)" if self.cm_per_px else "Speed (px/s)")
        if not i:
            ax.set_ylabel("Density")
    label = self.title_root.format(**self.metadata)
    fig.suptitle(f"{label} Speed density")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_speed_histogram(
        cls, experiments: list["Experiment"], n_bins=100,
        title: str = "Subjects Speed", hist_range: tuple[float, float] = (0, 200),
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "", log_x: bool = False,
):
    """Plot density histograms of the speed pooled over experiments.

    The figure has one row per filter group and one column per period. The
    non-negative speed samples of all experiments in a group are concatenated
    before the histogram is taken. All experiments of a group must share
    ``cm_per_px``.

    Args:
        experiments: Experiments to pool.
        n_bins: Number of bins (linear), or number of log-spaced bin edges
            (``log_x``).
        title: Figure title.
        hist_range: x limits of the linear histogram.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        filter_args: One filter dict per group (row), or None for one row.
        group_label: Format string, filled with the group's filter dict, that
            is prefixed to the y label when there is more than one group.
        log_x: Use log-spaced bins and a log x axis; the panel title then
            reports the percentage of zero samples.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        if not experiments_:
            continue
        speed_label = "Speed (cm/s)" if experiments_ and experiments_[0].cm_per_px else "Speed (px/s)"
        assert len({exp.cm_per_px for exp in experiments_}) == 1
        items = []
        for experiment in experiments_:
            # Second element of calculate_speed()'s result is the speed array.
            _, data, _, _ = getattr(experiment, data_name).calculate_speed()
            items.append(data[data >= 0])
        if log_x:
            data = np.concatenate(items)
            total = len(data)
            data = data[data > 0]
            if not len(data):
                continue
            non_zero = len(data)
            percent_zero = (total - non_zero) / total * 100
            ax.hist(data, bins=get_log_bins(data, n_bins), density=True)
            ax.set_xscale("log")
            # The period name is only shown on the first row.
            if i:
                ax.set_title(f"zeros = {percent_zero:0.2f}%")
            else:
                ax.set_title(f"$\\bf{{{t}}}$ zeros = {percent_zero:0.2f}%")
        else:
            # NOTE(review): nesting of this branch reconstructed from a
            # de-indented paste — confirm against the original file.
            if items:
                data = np.concatenate(items)
                if len(data):
                    ax.hist(data, bins=n_bins, density=True)
            if not i:
                ax.set_title(f"$\\bf{{{t}}}$")
            ax.set_xlim(*hist_range)
        group = ""
        if n_groups > 1:
            group = group_label.format(**filter_group)
        if i == n_groups - 1:
            ax.set_xlabel(speed_label)
        if not j:
            ax.set_ylabel(f"{group}Density")
    fig.suptitle(f"{title}")
    save_or_show(save_fig_root, save_fig_prefix)
def plot_categorical_values(
        self, measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), save_fig_root: None | Path = None,
        save_fig_prefix: str = "",
):
    """Plot the categorical value of ``measure`` over time, one panel per period.

    Args:
        measure, measure_options, categoricals: Passed to
            ``convert_to_categoricals`` to classify each period's samples.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    panels = zip(self._get_data_items(self), axs.flatten())
    for (period_data, period_title), axis in panels:
        converted = self.convert_to_categoricals(period_data, measure, categoricals, measure_options)
        # The middle element (time differences) is not needed for plotting.
        sample_times, _, category_index = converted
        period_data.plot_categorical_values(sample_times, category_index, categoricals, fig, axis)
        axis.set_title(period_title)
    heading = self.title_root.format(**self.metadata)
    fig.suptitle(f"Subject {heading}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_side_of_box(
        cls, experiments: list["Experiment"], measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict | None = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Overlay the categorical value of ``measure`` for several experiments.

    One panel per period; every experiment is drawn on the same axis.

    Args:
        experiments: Experiments to plot.
        measure, measure_options, categoricals: Passed to each experiment's
            ``convert_to_categoricals``.
        title: Figure title.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    for (data_name, t), ax in zip(cls._get_data_items(), axs.flatten()):
        for experiment in experiments:
            times, _, categoricals_index = experiment.convert_to_categoricals(
                data_name, measure, categoricals, measure_options
            )
            getattr(experiment, data_name).plot_categorical_values(times, categoricals_index, categoricals, fig, ax)
        # NOTE(review): nesting reconstructed from a de-indented paste; the
        # title is assumed to be set once per axis.
        ax.set_title(t)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def _descriptor_to_point(self, point: tuple[str, str] | tuple[float, float]):
cx, cy = self.box_center
bw, bh = self.box_size
match point[0]:
case "left":
x = cx - bw / 2
case "right":
x = cx + bw / 2
case "center":
x = cx
case int() | float():
x = point[0]
case _:
raise ValueError(f"Can't recognize {point[0]}")
match point[1]:
case "bottom":
y = cy + bh / 2
case "top":
y = cy - bh / 2
case "center":
y = cy
case int() | float():
y = point[1]
case _:
raise ValueError(f"Can't recognize {point[1]}")
return x, y
def plot_distance_from_point(
        self, point: tuple[str, str] | None = None, save_fig_root: None | Path = None,
        save_fig_prefix: str = "", post_title: str = "distance from teaball corner",
):
    """Plot the subject's distance from a point, one panel per period.

    Args:
        point: Point descriptor accepted by ``_descriptor_to_point``;
            defaults to ``self.cue_point``.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        post_title: Text appended to the experiment label in the figure title.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    target = self._descriptor_to_point(self.cue_point if point is None else point)
    panels = zip(self._get_data_items(self), axs.flatten())
    for column, ((period_data, period_title), axis) in enumerate(panels):
        # Only the first panel keeps the default y label.
        label_override = {"y_label": ""} if column else {}
        period_data.plot_distance_from_point(target, fig, axis, **label_override)
        axis.set_title(f"$\\bf{{{period_title}}}$")
    heading = self.title_root.format(**self.metadata)
    fig.suptitle(f"{heading} {post_title}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_distance_from_point(
        cls, experiments: list["Experiment"], point: tuple[str, str] | None = None,
        title: str = "Subjects distance from teaball corner",
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
):
    """Overlay the distance-from-point traces of several experiments.

    The figure has one row per filter group and one column per period; every
    experiment of a group is drawn on the same axis. All experiments of a
    group must share ``cm_per_px``.

    Args:
        experiments: Experiments to plot.
        point: Point descriptor accepted by ``_descriptor_to_point``; when
            None each experiment uses its own ``cue_point``.
        title: Figure title.
        save_fig_root, save_fig_prefix: Passed to ``save_or_show``.
        filter_args: One filter dict per group (row), or None for one row.
        group_label: Format string, filled with the group's filter dict, that
            is prefixed to the y label when there is more than one group.
    """
    n_groups = 1 if filter_args is None else len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
    for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
            experiments, periods, axs, filter_args):
        if not experiments_:
            continue
        dist_label = "cm" if experiments_ and experiments_[0].cm_per_px else "px"
        assert len({exp.cm_per_px for exp in experiments_}) == 1
        for experiment in experiments_:
            group = ""
            if n_groups > 1:
                group = group_label.format(**filter_group)
            # x label only on the bottom row, y label only in the first column.
            kwargs = {"y_label": ""}
            if i != n_groups - 1:
                kwargs["x_label"] = ""
            if not j:
                kwargs["y_label"] = f"{group}Distance ({dist_label})"
            # The descriptor is resolved per experiment, against its own box.
            point_xy = experiment._descriptor_to_point(experiment.cue_point if point is None else point)
            getattr(experiment, data_name).plot_distance_from_point(point_xy, fig, ax, **kwargs)
        # NOTE(review): nesting reconstructed from a de-indented paste; the
        # title is assumed to be set once per axis.
        if not i:
            ax.set_title(f"$\\bf{{{t}}}$")
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_bouts(
        self, measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float, refractory_period: float = 0,
        downsample_factor: int = 1, min_bout_frames: int = 0, save_fig_root: None | Path = None,
        save_fig_prefix: str = "", plot_type: Literal["scatter", "count", "mean", "duration"] = "scatter",
        unit: Literal["percent", "times"] = "times",
):
    """Plot the bouts of ``measure`` for this experiment, one panel per data item.

    The bout parameters are forwarded unchanged to each data item's
    ``plot_bout``; ``unit == "percent"`` is passed on as ``norm_time``. Each
    panel title reports the number of bouts returned by ``plot_bout``.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)

    absolute = unit == "times"
    x_label = "Bouts"
    if plot_type == "scatter":
        x_label = "Time (min)"
        y_label = "Duration (s)"
    elif plot_type == "count":
        y_label = "Count" if absolute else "Count / s"
    elif plot_type == "mean":
        y_label = "Mean duration (s)"
    elif plot_type == "duration":
        y_label = "Total duration " + ("(s)" if absolute else "(percent)")
    else:
        y_label = ""

    normalize = unit == "percent"
    panels = zip(self._get_data_items(self), axs.flatten())
    for column, ((period_data, period_title), ax) in enumerate(panels):
        bouts = period_data.plot_bout(
            measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames,
            plot_type, fig, ax, x_label=x_label, norm_time=normalize,
            y_label="" if column else y_label,  # y axis is shared; label the first panel only
        )
        ax.set_title(f"$\\bf{{{period_title}}}$ {len(bouts)} bouts")

    heading = self.title_root.format(**self.metadata)
    measure_name = measure.capitalize().replace("_", " ")
    fig.suptitle(f"{heading} {measure_name} bouts")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_bouts(
        cls, experiments: list["Experiment"], measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float,
        refractory_period: float = 0, downsample_factor: int = 1, min_bout_frames: int = 0,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
        plot_type: Literal["scatter", "count", "mean", "duration"] = "scatter", unit: Literal["percent", "times"] = "times",
):
    """Plot bouts of ``measure`` for groups of experiments.

    Rows are the filter groups (one row when ``filter_args`` is ``None``) and
    columns the items of ``_get_data_items``. For ``plot_type == "scatter"``
    every experiment draws itself through ``plot_bout``. For the other plot
    types one value per experiment is computed from
    ``calculate_measure_bout`` and shown as black dots over a bar at the
    group mean.
    """
    n_groups = len(filter_args) if filter_args is not None else 1
    periods = cls._get_data_items()
    fig, axs = plt.subplots(n_groups, len(periods), sharey=True, sharex=False)

    absolute = unit == "times"
    x_label = "Bouts"
    if plot_type == "scatter":
        x_label = "Time (min)"
        y_label = "Duration (s)"
    elif plot_type == "count":
        y_label = "Count" if absolute else "Count / s"
    elif plot_type == "mean":
        y_label = "Mean duration (s)"
    elif plot_type == "duration":
        y_label = "Total duration " + ("(s)" if absolute else "(percent)")
    else:
        y_label = ""

    bout_args = (measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames)
    cells = cls._iter_periods_and_groups(experiments, periods, axs, filter_args)
    for members, row, filter_group, col, data_name, period_title, ax in cells:
        if not members:
            continue
        assert len({exp.cm_per_px for exp in members}) == 1

        group = group_label.format(**filter_group) if n_groups > 1 else ""
        kwargs = {
            # Only the bottom row shows the x label, only the first column the y label.
            "x_label": x_label if row == n_groups - 1 else "",
            "y_label": "" if col else f"{group}{y_label}",
        }

        if plot_type == "scatter":
            for experiment in members:
                getattr(experiment, data_name).plot_bout(*bout_args, plot_type, fig, ax, **kwargs)
        else:
            all_bouts = [
                getattr(experiment, data_name).calculate_measure_bout(*bout_args)
                for experiment in members
            ]
            durations = np.array(
                [getattr(experiment, data_name).duration for experiment in members], dtype=np.float64
            )

            # Element 1 of each bout is summed as its duration.
            if plot_type == "count":
                values = np.array([len(bouts) for bouts in all_bouts], dtype=np.float64)
                if unit == "percent":
                    values /= durations
            elif plot_type == "mean":
                values = np.array(
                    [sum(b[1] for b in bouts) / len(bouts) if bouts else 0 for bouts in all_bouts],
                    dtype=np.float64
                )
            elif plot_type == "duration":
                values = np.array([sum(b[1] for b in bouts) for bouts in all_bouts], dtype=np.float64)
                if unit == "percent":
                    values = values / durations * 100
            else:
                assert False

            ax.bar([0], [np.mean(values)])
            ax.plot([0] * len(values), values, "k.")
            ax.set_xticks([0], [""])
            if kwargs["x_label"]:
                ax.set_xlabel(kwargs["x_label"])
            if kwargs["y_label"]:
                ax.set_ylabel(kwargs["y_label"])

        if not row:
            ax.set_title(f"$\\bf{{{period_title}}}$")

    measure_name = measure.capitalize().replace("_", " ")
    fig.suptitle(f"{measure_name} bouts")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def _get_categorical_percent(
cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray],
sorted_categoricals: tuple[str, ...], unit: Literal["percent", "times"] = "times",
):
if unit == "times":
percents = np.empty((len(categoricals_index), len(sorted_categoricals)))
for e, (exp_categoricals_index, time_diff) in enumerate(zip(categoricals_index, times_diff)):
if len(exp_categoricals_index) != len(time_diff):
raise ValueError("Provided time diff and categorical index are not the same length")
for i in range(len(sorted_categoricals)):
percents[e, i] = np.sum(time_diff[exp_categoricals_index == i])
elif unit == "percent":
counts = np.array([
[np.sum(arr == i) for i in range(len(sorted_categoricals))]
for arr in categoricals_index
])
percents = counts / np.sum(counts, axis=1, keepdims=True) * 100
else:
raise ValueError(f"Unknown unit {unit}")
mean_prop = np.mean(percents, axis=0).squeeze()
return mean_prop, [prop.squeeze() for prop in percents]
@classmethod
def _plot_percents(
        cls, categoricals_index: list[np.ndarray], times_diff: list[np.ndarray],
        sorted_categoricals: tuple[str, ...],
        fig: plt.Figure, ax: plt.Axes,
        x_label: str = "Teaball side", y_label: str = "% time spent",
        unit: Literal["percent", "times"] = "times",
):
    """Draw one bar per category at the mean of ``_get_categorical_percent``.

    When more than one experiment is given, each experiment's values are
    also drawn as black dots. Empty ``x_label`` / ``y_label`` leave the
    corresponding axis label untouched.
    """
    mean_prop, per_experiment = cls._get_categorical_percent(
        categoricals_index, times_diff, sorted_categoricals, unit
    )
    positions = np.arange(len(sorted_categoricals))
    ax.bar(positions, mean_prop, tick_label=sorted_categoricals)

    if len(categoricals_index) > 1:
        for values in per_experiment:
            ax.plot(positions, values.squeeze(), "k.")

    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)
def plot_categorical_percent(
        self, measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        unit: Literal["percent", "times"] = "times", x_label="Odor side", action_label: str = "spent in side",
):
    """Bar-plot this experiment's time per category, one panel per data item.

    Each data item is converted with ``convert_to_categoricals`` and drawn
    with ``_plot_percents``. ``unit`` selects between total time and percent
    of time for both the y label and the figure title.
    """
    fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
    in_percent = unit == "percent"
    y_label = "% time spent" if in_percent else "total time spent (s)"

    panels = zip(self._get_data_items(self), axs.flatten())
    for column, ((period_data, period_title), ax) in enumerate(panels):
        _, time_diff, categoricals_index = self.convert_to_categoricals(
            period_data, measure, categoricals, measure_options
        )
        # The y axis is shared, so only the first panel carries its label.
        panel_y_label = "" if column else y_label
        self._plot_percents(
            [categoricals_index], [time_diff], categoricals, fig, ax, x_label, panel_y_label, unit,
        )
        ax.set_title(f"$\\bf{{{period_title}}}$")

    heading = self.title_root.format(**self.metadata)
    tp = "total" if unit == "times" else "%"
    fig.suptitle(f"{heading} {tp} time {action_label}")
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_categorical_percent(
        cls, experiments: list["Experiment"], measure: CAT_MEASURE_TYPE = "occupancy",
        measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        filter_args: list[dict] | None = None, group_label: str = "",
        unit: Literal["percent", "times"] = "times", x_label: str = "Odor side",
):
    """Bar-plot time per category for groups of experiments.

    Rows are the filter groups (one row when ``filter_args`` is ``None``) and
    columns the items of ``_get_data_items``. Each cell is drawn by
    ``_plot_percents`` from the categoricals of all experiments in it; cells
    without experiments are left empty.
    """
    n_groups = len(filter_args) if filter_args is not None else 1
    periods = cls._get_data_items()
    fig, axs = plt.subplots(n_groups, len(periods), sharey=True, sharex=False)
    y_text = "% time spent" if unit == "percent" else "total time spent (s)"

    cells = cls._iter_periods_and_groups(experiments, periods, axs, filter_args)
    for members, row, filter_group, col, data_name, period_title, ax in cells:
        index_all = []
        times_all = []
        for experiment in members:
            _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                data_name, measure, categoricals, measure_options
            )
            index_all.append(categoricals_index)
            times_all.append(times_diff)
        if not times_all:
            continue

        group = group_label.format(**filter_group) if n_groups > 1 else ""
        kwargs = {
            # Only the bottom row shows the x label, only the first column the y label.
            "x_label": x_label if row == n_groups - 1 else "",
            "y_label": "" if col else f"{group}{y_text}",
        }
        cls._plot_percents(index_all, times_all, categoricals, fig, ax, **kwargs, unit=unit)

        if not row:
            ax.set_title(f"$\\bf{{{period_title}}}$")

    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def export_multi_experiment_categorical_percent(
        cls, experiments: list["Experiment"], filename: Path,
        measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"),
        unit: Literal["percent", "times"] = "times",
        use_chopped_data: bool = False,
) -> None:
    """Write the per-category time of every experiment and stage to a CSV.

    Each row holds the experiment's id columns, the stage name and one value
    per category. Stages are ``Pre`` / ``Trial`` / ``Post``, or the keys of
    ``chopped_pos_data`` when ``use_chopped_data`` is set. The header is
    built from the first experiment, so ``experiments`` must not be empty.
    """
    suffix = "sec" if unit == "times" else "%"
    value_columns = [f"{label} {suffix}" for label in categoricals]
    header = experiments[0].get_header_id_columns() + ["Stage"] + value_columns

    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)

        for experiment in experiments:
            if use_chopped_data:
                stages = list(experiment.chopped_pos_data.keys())
                stages_data = [experiment.chopped_pos_data[stage] for stage in stages]
            else:
                stages = "Pre", "Trial", "Post"
                stages_data = experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data

            id_columns = experiment.get_id_columns()
            for stage, stage_data in zip(stages, stages_data):
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    stage_data, measure, categoricals, measure_options
                )
                # With a single experiment the "mean" is that experiment's own values.
                values, _ = cls._get_categorical_percent(
                    [categoricals_index], [times_diff], categoricals, unit
                )
                row = id_columns + [stage, *values]
                writer.writerow([str(item) for item in row])
@classmethod
def plot_multi_experiment_merged_by_period_categorical_percent(
        cls, experiments: list["Experiment"], filter_args: list[dict],
        measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "", group_label: str = "",
        unit: Literal["percent", "times"] = "times", x_label: str = "Teaball side",
):
    """Bar-plot time per category with one panel per period and grouped bars.

    Every panel corresponds to one item of ``_get_data_items``. Within a
    panel each category gets a cluster of bars, one bar per filter group, at
    the group mean with the individual experiments drawn as black dots. The
    legend is labelled with ``group_label`` formatted by each filter group.
    """
    n_groups = len(filter_args)
    periods = cls._get_data_items()
    fig, axs = plt.subplots(1, len(periods), sharey=True, sharex=True)
    axes = list(axs.flatten())

    # One slot of each unit-wide cluster is left free as a gap between clusters.
    bar_width = 1 / (n_groups + 1)
    n_categoricals = len(categoricals)
    y_text = "% time spent" if unit == "percent" else "total time spent (s)"

    for group_idx, filter_group in enumerate(filter_args):
        members = cls.filter(experiments, **filter_group)
        for period_idx, (data_name, period_title) in enumerate(periods):
            index_all = []
            times_all = []
            for experiment in members:
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    data_name, measure, categoricals, measure_options
                )
                index_all.append(categoricals_index)
                times_all.append(times_diff)
            if not times_all:
                continue

            ax = axes[period_idx]
            mean_prop, per_experiment = cls._get_categorical_percent(
                index_all, times_all, categoricals, unit
            )
            ax.bar(
                group_idx * bar_width + np.arange(n_categoricals), mean_prop, bar_width,
                label=group_label.format(**filter_group),
            )
            for values in per_experiment:
                ax.plot(group_idx * bar_width + np.arange(n_categoricals), values, "k.")

            ax.set_xlabel(x_label)
            if not period_idx:
                ax.set_ylabel(f"{y_text}")
            if not group_idx:
                ax.set_title(f"$\\bf{{{period_title}}}$")

    for ax in axes:
        # Put each tick in the middle of its cluster of bars.
        ax.set_xticks(np.arange(n_categoricals) + (1 - bar_width) / 2 - bar_width / 2, categoricals)
        ax.set_xlim(-bar_width, n_categoricals + bar_width)

    handles, labels = axes[-1].get_legend_handles_labels()
    fig.legend(handles, labels, ncols=n_groups, bbox_to_anchor=(0, 0), loc=2)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_merged_by_group_categorical_percent(
        cls, experiments: list["Experiment"], filter_args: list[dict],
        measure: CAT_MEASURE_TYPE = "occupancy", measure_options: dict = None,
        categoricals: tuple[str, ...] = ("Teaball-side", "Far-side"), title: str = "Subjects motion",
        save_fig_root: None | Path = None, save_fig_prefix: str = "", show_titles: bool = True,
        show_legend: bool = True, only_categorical: str | None = None, show_xlabel: bool = True,
        unit: Literal["percent", "times"] = "times", group_label: str = "", x_label: str = "Teaball side"
):
    """Bar-plot time per category with one panel per group and grouped bars.

    Every panel (stacked vertically) corresponds to one entry of
    ``filter_args``. Within a panel each category gets a cluster of bars, one
    bar per item of ``_get_data_items``, at the group mean with the
    individual experiments drawn as black dots.

    :param only_categorical: When given, only this entry of ``categoricals``
        is drawn; the values are still computed over all categories.
    :param show_titles: Title each panel with the formatted ``group_label``.
    :param show_legend: Add a figure legend naming the periods.
    :param show_xlabel: Label the x axis of the bottom panel.
    :raises ValueError: If ``only_categorical`` is not in ``categoricals``.
    """
    n_groups = len(filter_args)
    periods = cls._get_data_items()
    n_periods = len(periods)
    # squeeze=False always returns a 2-D array of axes. Without it a single
    # group yields a bare Axes object, which has no ``flatten``.
    fig, axs = plt.subplots(n_groups, 1, sharey=True, sharex=True, squeeze=False)
    axes = list(axs.flatten())

    # One slot of each unit-wide cluster is left free as a gap between clusters.
    bar_width = 1 / (n_periods + 1)
    n_categoricals = len(categoricals)
    categoricals_s = slice(0, n_categoricals)
    n_categoricals_used = n_categoricals
    categoricals_used = categoricals
    if only_categorical:
        selected = categoricals.index(only_categorical)
        categoricals_s = slice(selected, selected + 1)
        n_categoricals_used = 1
        categoricals_used = [only_categorical]

    y_text = "% time spent" if unit == "percent" else "total time spent (s)"
    for i, filter_group in enumerate(filter_args):
        group = group_label.format(**filter_group) if n_groups > 1 else ""
        experiments_ = cls.filter(experiments, **filter_group)
        for j, (data_name, t) in enumerate(periods):
            categoricals_index_all = []
            times_all = []
            for experiment in experiments_:
                _, times_diff, categoricals_index = experiment.convert_to_categoricals(
                    data_name, measure, categoricals, measure_options
                )
                categoricals_index_all.append(categoricals_index)
                times_all.append(times_diff)
            if not times_all:
                continue

            ax = axes[i]
            mean_prop, percents = cls._get_categorical_percent(
                categoricals_index_all, times_all, categoricals, unit
            )
            positions = j * bar_width + np.arange(n_categoricals_used)
            ax.bar(positions, mean_prop[categoricals_s], bar_width, label=t)
            for prop in percents:
                ax.plot(positions, prop[categoricals_s], "k.")

            if i == n_groups - 1 and show_xlabel:
                ax.set_xlabel(x_label)
            ax.set_ylabel(f"{group}{y_text}")
            if not j and show_titles:
                ax.set_title(group_label.format(**filter_group))

    for ax in axes:
        # Put each tick in the middle of its cluster of bars.
        ax.set_xticks(np.arange(n_categoricals_used) + (1 - bar_width) / 2 - bar_width / 2, categoricals_used)
        ax.set_xlim(-bar_width, n_categoricals_used - bar_width)

    if show_legend:
        handles, labels = axes[-1].get_legend_handles_labels()
        fig.legend(handles, labels, ncols=1, bbox_to_anchor=(0, 0), loc=2)
    fig.suptitle(title)
    save_or_show(save_fig_root, save_fig_prefix, width_inch=3)
@classmethod
def export_multi_experiment_frames(cls, experiments: list["Experiment"], filename: Path) -> None:
    """Write one CSV row of geometry, frame counts and durations per experiment.

    After the id columns come the box centre and size, the image size and
    ``cm_per_px``, then the number of time samples and the time span (last
    minus first sample, 0 when empty) of the whole recording and of the
    pre / trial / post periods. The header is built from the first
    experiment, so ``experiments`` must not be empty.
    """
    value_columns = [
        "box_center_x_px", "box_center_y_px", "box_width_px", "box_height_px", "image_width_px", "image_height_px",
        "cm_per_px", "total_frames", "pre_frames", "trial_frames", "post_frames", "total_duration", "pre_duration",
        "trial_duration", "post_duration",
    ]
    header = experiments[0].get_header_id_columns() + value_columns

    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter=",")
        writer.writerow(header)

        for experiment in experiments:
            id_columns = experiment.get_id_columns()
            # Same order as the *_frames / *_duration header columns.
            tracks = (
                experiment.pos_data, experiment.pre_pos_data, experiment.trial_pos_data, experiment.post_pos_data
            )
            frame_counts = [len(track.times) for track in tracks]
            time_spans = [
                (track.times[-1] - track.times[0]) if len(track.times) else 0
                for track in tracks
            ]
            row = id_columns + [
                *experiment.box_center,
                *experiment.box_size,
                *experiment.image_size,
                experiment.cm_per_px,
                *frame_counts,
                *time_spans,
            ]
            writer.writerow([str(item) for item in row])
@classmethod
def filter(cls, experiments: list["Experiment"], **metadata):
for key, value in metadata.items():
if value is not None:
experiments = [t for t in experiments if t.metadata[key] == value]
return experiments
@classmethod
def count_motion_range(
cls, experiments: list["Experiment"], measure: Literal["speed", "motion_index"] = "motion_index",
downsample_factor: int = 1, hab_segment: tuple[float, float] = (), trial_segment: tuple[float, float] = (),
post_segment: tuple[float, float] = (),
) -> dict[float, int]:
items = []
for exp in experiments:
data = exp.pos_data
if hab_segment or trial_segment or post_segment:
data = exp.extract_subject_period(hab_segment, trial_segment, post_segment)
cm_per_px = exp.cm_per_px or 1
if measure == "motion_index":
motion = data.motion_index * cm_per_px
else:
_, motion, _, _ = data.calculate_speed()
if downsample_factor != 1:
motion = decimate(motion, downsample_factor)
items.append(motion[motion >= 0])
data = np.concatenate(items)
n = len(data)
counts = {}
zero_mask = data == 0
counts[0] = np.sum(zero_mask).item()
data = data[np.logical_not(zero_mask)]
if not len(data):
return counts
min_val = np.min(data)
assert min_val > 0
max_val = np.max(data)
min_order = int(math.floor(math.log10(min_val)))
max_order = max(int(math.floor(math.log10(max_val))), 0)
for order in range(max_order, min_order - 1, -1):
for offset in (0, math.log10(0.5)):
val = math.pow(10, order + offset)
mask = data >= val
counts[val] = np.sum(mask).item()
data = data[np.logical_not(mask)]
assert not len(data)
assert sum(counts.values()) == n
return counts
@classmethod
def plot_motion_index_range(
        cls, experiments: list["Experiment"], ax: plt.Axes = None,
        measure: Literal["speed", "motion_index"] = "motion_index",
) -> None:
    """Plot a log-binned density histogram of the pooled positive motion values.

    When ``ax`` is omitted a new figure is created and shown; otherwise the
    histogram is drawn into ``ax`` and showing is left to the caller.
    """
    own_figure = ax is None

    positive = []
    for exp in experiments:
        cm_per_px = exp.cm_per_px or 1
        if measure == "motion_index":
            motion = exp.pos_data.motion_index * cm_per_px
        else:
            _, motion, _, _ = exp.pos_data.calculate_speed()
        # Zero and negative values cannot be placed on the log axis.
        positive.append(motion[motion > 0])
    data = np.concatenate(positive)

    if own_figure:
        fig, ax = plt.subplots()
    ax.hist(data, bins=get_log_bins(data, 100), density=True)
    ax.set_xscale("log")
    ax.set_xlabel("Motion index" if measure == "motion_index" else "Speed")
    ax.set_ylabel("Density")
    if own_figure:
        plt.show()
@classmethod
def plot_compare_motion_index_range(
        cls, data1: Path, data2: Path, title1: str, title2: str, frame_rate,
        measure: Literal["speed", "motion_index"] = "motion_index", cm_per_px: float | None = None,
) -> None:
    """Show the motion histograms of two track files stacked on a shared x axis.

    Each file is parsed with ``SubjectPos.parse_csv_track`` and wrapped in a
    placeholder ``Experiment`` so ``plot_motion_index_range`` can draw it.
    """
    fig, (top_ax, bottom_ax) = plt.subplots(nrows=2, sharex=True)

    panels = ((data1, title1, top_ax), (data2, title2, bottom_ax))
    for track_file, panel_title, ax in panels:
        pos_data = SubjectPos.parse_csv_track(track_file, frame_rate=frame_rate, cm_per_px=cm_per_px)
        # Placeholder experiment; only pos_data is assigned here.
        exp = Experiment(0, 0, 0, 0, 0, 0, "", "")
        exp.pos_data = pos_data
        Experiment.plot_motion_index_range([exp], ax=ax, measure=measure)
        ax.set_title(panel_title)

    plt.show()
@classmethod
def percentile_motion_range(
cls, subjects: list["SubjectPos"], percentiles: float | Sequence[float],
measure: Literal["speed", "motion_index"] = "motion_index", downsample_factor: int = 1,
) -> float | np.ndarray | np.floating:
items = []
for subject in subjects:
cm_per_px = subject.cm_per_px or 1
if measure == "motion_index":
motion = subject.motion_index * cm_per_px
else:
_, motion, _, _ = subject.calculate_speed()
if downsample_factor != 1:
motion = decimate(motion, downsample_factor)
items.append(motion[motion >= 0])
data = np.concatenate(items)
return np.percentile(data, percentiles)
def get_y_range_from_data(
self, measure: Literal["motion_index", "speed"], ylog: bool,
ylim: tuple[float, float] | None
) -> tuple[float, float]:
if ylim:
return ylim
cm = self.cm_per_px or 1
low = float("inf")
high = float("-inf")
for data in (self.pre_pos_data, self.trial_pos_data, self.post_pos_data):
index = data.motion_index * cm if measure == "motion_index" else data.calculate_speed(True)[1]
if ylog:
index = index[index > 0]
else:
index = index[index >= 0]
low = min(np.min(index), low)
high = max(np.max(index), high)
if ylog:
low = 10 ** (math.floor(2 * math.log10(low)) / 2)
high = 10 ** (math.ceil(2 * math.log10(high)) / 2)
else:
low = round(low)
high = round(high)
return low, high
def generate_images_from_data(
        self, measure: Literal["motion_index", "speed"], ylog: bool,
        ylim: tuple[float, float] | None, frame_rate: float, time_window_padding: float,
        image_size: tuple[int, int], output_fig_h: int, root_output: Path,
        bouts_parameters: dict[str, dict[str, float]], thresholds: dict[str, float],
) -> str:
    """Render one PNG per sample showing the measure around that sample.

    The whole trace is drawn once: samples at or below the freezing threshold
    in red, samples at or above the dart threshold in green, freezing bouts
    as red segments along the lower limit and dart bouts as green segments
    along the upper limit. For every sample the x window is then moved to
    ``time_window_padding`` seconds either side of it, the sample is marked
    with a blue star and the frame is saved. Frames whose file already
    exists are skipped.

    :param ylim: Explicit y limits; when falsy they come from
        ``get_y_range_from_data``.
    :param image_size: Only the width (element 0) is used, as the figure
        width in pixels.
    :param output_fig_h: Figure height in pixels.
    :param root_output: Images are written to ``root_output / measure``.
    :param bouts_parameters: Keyword arguments for ``calculate_bouts``, keyed
        by ``"<measure>_freezing"`` and ``"<measure>_dart"``.
    :param thresholds: Thresholds for colouring samples, same keys.
    :returns: A printf-style pattern (``..._%0Nd.png``) matching the frames.
    """
    root_output = root_output / measure
    root_output.mkdir(parents=True, exist_ok=True)

    cm = self.cm_per_px or 1
    low, high = self.get_y_range_from_data(measure, ylog, ylim)

    data = self.pos_data
    index = data.motion_index * cm if measure == "motion_index" else data.calculate_speed(True)[1]
    freezing_bouts = data.calculate_bouts(
        index, above_threshold=False, frame_rate=frame_rate, **bouts_parameters[f"{measure}_freezing"],
    )
    dart_bouts = data.calculate_bouts(
        index, above_threshold=True, frame_rate=frame_rate, **bouts_parameters[f"{measure}_dart"],
    )
    froze = index <= thresholds[f"{measure}_freezing"]
    darted = index >= thresholds[f"{measure}_dart"]

    n = len(index)
    pad_frames = int(time_window_padding * frame_rate)
    all_times = np.arange(n) / frame_rate
    # Number of digits used to zero-pad the frame number in the file names.
    padding = int(math.ceil(math.log10(n)))

    dpi = 150
    fig, ax = plt.subplots()
    try:
        fig.set_dpi(dpi)
        fig.set_size_inches(image_size[0] / dpi, output_fig_h / dpi)

        unit = "px" if not self.cm_per_px else "cm"
        ax.set_ylabel(
            measure.capitalize().replace("_", " ") + f" ({unit}/s)"
        )
        ax.set_ylim(low, high)
        if ylog:
            ax.set_yscale("log")
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.xaxis.set_visible(False)

        ax.plot(all_times, index, "k")
        ax.plot(all_times[froze], index[froze], ".r")
        ax.plot(all_times[darted], index[darted], ".g")
        for bouts, y, c in [(freezing_bouts, low, "r"), (dart_bouts, high, "g")]:
            for ts, duration in bouts:
                ax.plot([ts, ts + duration], [y, y], f"-{c}")

        text = ax.text(
            1, 1.15, "Y:",
            horizontalalignment='right',
            verticalalignment='top',
            transform=ax.transAxes
        )

        for i in tqdm.tqdm(range(n), total=n, unit="frames", desc=f"{self.title_root} - {measure}"):
            filename = root_output / f"{self.title_root}_{i:0{padding}}.png"
            if filename.exists():
                continue

            value = index[i].item()
            ts = (i - pad_frames) / frame_rate
            te = (i + pad_frames) / frame_rate

            markers = ax.plot([all_times[i]], [value], "*b")
            text.set_text(f"Y: {value:0.5f}")
            ax.set_ylim(low, high)
            ax.set_xlim(ts, te)

            fig.tight_layout()
            fig.savefig(filename, dpi=dpi)

            # Remove the per-frame marker so the next frame starts clean.
            for marker in markers:
                marker.remove()
    finally:
        # Close the figure; pyplot would otherwise keep one alive per call.
        plt.close(fig)

    return str(root_output / f"{self.title_root}_%0{padding}d.png")
def merge_video_with_images(
        self, parameters: "ExperimentParameters", frame_rate: float, image_size: tuple[int, int],
        output_fig_h: int,
        time_window_padding: float, video_filename: str | Path, ffmpeg: str | Path, temp_images: bool,
        above_measure: Literal["motion_index", "speed"] | None,
        below_measure: Literal["motion_index", "speed"] | None,
        bouts_parameters: dict[str, dict[str, float]], thresholds: dict[str, float],
):
    """Stack measure plots above and/or below the video using ffmpeg.

    The frames of each requested measure are rendered with
    ``generate_images_from_data`` (always with a log y axis) and stacked
    vertically with the video, in the order above-images, video,
    below-images. The result is written to
    ``parameters.video_root / "<title_root>_measures.mp4"``. ffmpeg's stdout
    and stderr are printed after it finishes.

    :param temp_images: Render the frames into a temporary directory that is
        removed afterwards, instead of under ``parameters.image_root``.
    :param above_measure: Measure drawn above the video, or ``None``.
    :param below_measure: Measure drawn below the video, or ``None``.
    :raises subprocess.CalledProcessError: If ffmpeg exits with a non-zero
        status.
    """
    video_root = parameters.video_root
    video_root.mkdir(parents=True, exist_ok=True)

    tmp_dir = None
    if temp_images:
        tmp_dir = TemporaryDirectory()
        image_root = Path(tmp_dir.name)
    else:
        image_root = parameters.image_root

    def image_input(measure):
        # ffmpeg input arguments for the rendered frames of ``measure``.
        if measure is None:
            return ()
        img_pat = self.generate_images_from_data(
            measure, True, None, frame_rate, time_window_padding, image_size,
            output_fig_h, image_root, bouts_parameters, thresholds,
        )
        return "-r", str(frame_rate), "-i", img_pat

    try:
        image_root = image_root / self.title_root
        image_root.mkdir(parents=True, exist_ok=True)

        above_arg = image_input(above_measure)
        below_arg = image_input(below_measure)

        if above_measure and below_measure:
            filter_arg = "[0:v][1:v][2:v]vstack=inputs=3[v]"
        elif above_measure or below_measure:
            filter_arg = "[0:v][1:v]vstack=inputs=2[v]"
        else:
            # vstack needs at least two inputs; with only the video, pass it
            # through unchanged.
            filter_arg = "[0:v]null[v]"

        # NOTE(review): no -y/-n flag is passed, so ffmpeg may prompt when
        # the output file already exists. Confirm the intended policy.
        args = [
            str(ffmpeg),
            *above_arg,
            "-i",
            str(video_filename),
            *below_arg,
            "-filter_complex",
            filter_arg,
            "-map",
            "[v]",
            "-c:v",
            "libx264",
            "-crf",
            "18",
            str(video_root / f"{self.title_root}_measures.mp4"),
        ]
        print(f'Command: {" ".join(args)}')
        result = subprocess.run(args, capture_output=True)
    finally:
        if tmp_dir is not None:
            tmp_dir.cleanup()

    print("stdout:\n")
    print(result.stdout.decode())
    print("stderr:\n")
    print(result.stderr.decode())
    result.check_returncode()
@classmethod
def export_experiments_csv(
        cls, experiments: list["Experiment"], data_root: Path, occupancy_meas_args, cue_categoricals,
        freezing_index_meas_args, freezing_speed_meas_args, freeze_categoricals,
        cat_unit: Literal["times", "percent"], include_chopped: bool,
        frame_rate: float, bouts_parameters: dict[str, dict[str, Any]],
):
    """Write all per-experiment CSV summaries under ``data_root``.

    Exports the experiment metadata, box-side occupancy, freezing by motion
    index and by speed (file names include the threshold), the mean motion
    index and speed, and the four bout tables. With ``include_chopped`` the
    occupancy and motion exports are repeated on the chopped data into
    ``*_pre.csv`` files.

    :param bouts_parameters: Keyword arguments for ``export_bouts``, keyed by
        bout measure name.
    """
    cls.export_multi_experiment_frames(experiments, data_root / "experiment_metadata.csv")

    cls.export_multi_experiment_categorical_percent(
        experiments, data_root / "box_sides.csv", measure="occupancy", measure_options=occupancy_meas_args,
        categoricals=cue_categoricals, unit=cat_unit,
    )
    cls.export_multi_experiment_categorical_percent(
        experiments,
        data_root / f"motion_index_freezing_threshold_{freezing_index_meas_args['threshold']:0.5f}.csv",
        measure="motion_index_freezing",
        measure_options=freezing_index_meas_args, categoricals=freeze_categoricals, unit=cat_unit,
    )
    cls.export_multi_experiment_categorical_percent(
        experiments, data_root / f"speed_freezing_threshold_{freezing_speed_meas_args['threshold']:0.5f}.csv",
        measure="speed_freezing",
        measure_options=freezing_speed_meas_args,
        categoricals=freeze_categoricals, unit=cat_unit,
    )

    cls.export_multi_experiment_motion(experiments, data_root / "motion_index_mean.csv", measure="motion_index")
    cls.export_multi_experiment_motion(experiments, data_root / "speed_mean.csv", measure="speed")

    for bout_measure in ("speed_freezing", "motion_index_freezing", "motion_index_dart", "speed_dart"):
        cls.export_bouts(
            experiments, data_root / f"bouts_{bout_measure}.csv", bout_measure, unit=cat_unit,
            frame_rate=frame_rate, **bouts_parameters[bout_measure],
        )

    if include_chopped:
        # Use the same category labels as box_sides.csv rather than the
        # callee's defaults, so both occupancy files agree.
        cls.export_multi_experiment_categorical_percent(
            experiments, data_root / "box_sides_pre.csv", use_chopped_data=True,
            measure="occupancy", measure_options=occupancy_meas_args,
            categoricals=cue_categoricals, unit=cat_unit,
        )
        cls.export_multi_experiment_motion(
            experiments, data_root / "motion_index_mean_pre.csv", use_chopped_data=True,
        )
        cls.export_multi_experiment_motion(
            experiments, data_root / "speed_mean_pre.csv", use_chopped_data=True, measure="speed",
        )
@classmethod
def export_single_subject_figures(
        cls, experiments: list["Experiment"], figure_root: Path, occupancy_meas_args,
        cat_unit: Literal["times", "percent"], prefix_pat: str, frame_rate: float,
        bouts_parameters: dict[str, dict[str, Any]],
):
    """Save the full set of single-subject figures for every experiment.

    Figures go to ``figure_root / "subject" / <label>`` where the label is
    ``prefix_pat`` formatted with the experiment's metadata. Bout figures go
    to a ``bouts/<plot_type>`` sub-folder of it.

    :param bouts_parameters: Keyword arguments for ``plot_bouts``, keyed by
        bout measure name.
    """
    bout_measures = ("speed_freezing", "motion_index_freezing", "motion_index_dart", "speed_dart")
    bout_plot_types = ("scatter", "count", "mean", "duration")

    for experiment in tqdm.tqdm(list(experiments), desc="subject"):
        label = prefix_pat.format(**experiment.metadata)
        subject_root = figure_root / "subject" / label

        experiment.plot_occupancy(
            scale_to_one=False, intensity_limit=1e-5,
            save_fig_root=subject_root, save_fig_prefix=f"intensity_limit100K_{label}",
        )
        experiment.plot_occupancy(
            intensity_limit=1e-6,
            save_fig_root=subject_root, save_fig_prefix=f"intensity_limit1M_{label}",
        )
        experiment.plot_occupancy(
            gaussian_sigma=5,
            save_fig_root=subject_root, save_fig_prefix=f"occupancy_sigma5_{label}",
        )
        experiment.plot_occupancy(
            gaussian_sigma=1,
            save_fig_root=subject_root, save_fig_prefix=f"occupancy_sigma1_{label}",
        )

        experiment.plot_motion_index_histogram(
            n_bins=50, hist_range=(0, .3),
            save_fig_root=subject_root, save_fig_prefix=f"motion_index_histogram_{label}",
        )
        experiment.plot_motion_index_histogram(
            n_bins=25, hist_range=(1e-5, .3),
            save_fig_root=subject_root, save_fig_prefix=f"motion_index_histogram_log_{label}", log_x=True,
        )
        experiment.plot_motion_index(
            y_limit=.2,
            save_fig_root=subject_root, save_fig_prefix=f"motion_index_{label}",
        )

        experiment.plot_speed_histogram(
            n_bins=50, hist_range=(0, 30),
            save_fig_root=subject_root, save_fig_prefix=f"speed_histogram_{label}",
        )
        experiment.plot_speed_histogram(
            n_bins=25, hist_range=(1e-3, 30),
            save_fig_root=subject_root, save_fig_prefix=f"speed_histogram_log_{label}", log_x=True,
        )
        experiment.plot_speed(
            y_limit=20,
            save_fig_root=subject_root, save_fig_prefix=f"speed_{label}",
        )

        experiment.plot_distance_from_point(
            save_fig_root=subject_root, save_fig_prefix=f"distance_{label}",
        )
        experiment.plot_categorical_percent(
            measure="occupancy", measure_options=occupancy_meas_args, unit=cat_unit,
            save_fig_root=subject_root, save_fig_prefix=f"side_{label}", x_label="",
        )

        for plot_type in bout_plot_types:
            bouts_root = subject_root / "bouts" / plot_type
            for bout_measure in bout_measures:
                experiment.plot_bouts(
                    bout_measure, frame_rate=frame_rate, plot_type=plot_type, unit=cat_unit,
                    save_fig_root=bouts_root, save_fig_prefix=f"{bout_measure}_{label}",
                    **bouts_parameters[bout_measure],
                )
@classmethod
def export_multi_experiment_figures(
        cls, experiments: list["Experiment"], figure_root: Path, occupancy_meas_args,
        cat_unit: Literal["times", "percent"], label: str, cue_categoricals, freezing_index_meas_args,
        freezing_speed_meas_args, freeze_categoricals, frame_rate: float,
        bouts_parameters: dict[str, dict[str, Any]],
        filter_args: list[dict] | None = None, group_label: str = "", sub_dir: str = "grouped",
):
    """Request every grouped figure for ``experiments``.

    The work is delegated to the ``plot_multi_experiment_*`` classmethods. Every call receives a
    ``save_fig_root`` below ``figure_root / sub_dir`` and a ``save_fig_prefix`` containing ``label``.
    The sequence is: occupancy maps, motion-index and speed histograms/traces, distance from the cue,
    categorical (side / freezing) plots and bout plots. The "merged" categorical plots at the end are
    only requested when ``filter_args`` is truthy.

    :param experiments: experiments handed unchanged to every plotting call.
    :param figure_root: root directory below which all output directories are derived.
    :param occupancy_meas_args: ``measure_options`` used for the ``"occupancy"`` measure.
    :param cat_unit: forwarded as ``unit`` to the categorical and bout plots.
    :param label: text used in plot titles and file prefixes.
    :param cue_categoricals: ``categoricals`` used for the ``"occupancy"`` measure.
    :param freezing_index_meas_args: ``measure_options`` for ``"motion_index_freezing"``; its
        ``"threshold"`` entry is shown in the titles.
    :param freezing_speed_meas_args: ``measure_options`` for ``"speed_freezing"``; its ``"threshold"``
        entry is shown in the titles.
    :param freeze_categoricals: ``categoricals`` used for both freezing measures.
    :param frame_rate: forwarded to the bout plots.
    :param bouts_parameters: per bout-measure keyword arguments, expanded into the bout plot calls.
        Must contain the keys ``"speed_freezing"``, ``"motion_index_freezing"``,
        ``"motion_index_dart"`` and ``"speed_dart"``.
    :param filter_args: forwarded to every call; when falsy the merged plots are skipped.
    :param group_label: forwarded to every call.
    :param sub_dir: directory below ``figure_root`` that receives the figures.
    """
    out_root = figure_root / sub_dir
    # keyword arguments shared by every plotting call
    grouping = {"filter_args": filter_args, "group_label": group_label}

    # --- occupancy maps: (directory, file suffix, variant specific options) ---
    density_title = f"{label} occupancy density"
    smoothed_title = f"{label} occupancy"
    occupancy_variants = (
        ("100K", "intensity_limit100K",
         {"title": density_title, "scale_to_one": False, "intensity_limit": 1e-5}),
        ("1M", "intensity_limit1M", {"title": density_title, "intensity_limit": 1e-6}),
        ("sigma5", "sigma5", {"title": smoothed_title, "gaussian_sigma": 5}),
        ("sigma1", "sigma1", {"title": smoothed_title, "gaussian_sigma": 1}),
    )
    for folder, suffix, options in occupancy_variants:
        cls.plot_multi_experiment_occupancy(
            experiments,
            save_fig_root=out_root / "occupancy" / folder,
            save_fig_prefix=f"occupancy_{label}_{suffix}",
            **grouping, **options,
        )

    # --- histograms and time traces of the continuous measures ---
    # measure -> (linear histogram range, log histogram range, y limit of the trace)
    motion_settings = {
        "motion_index": ((0, .3), (1e-5, .3), .2),
        "speed": ((0, 30), (1e-3, 30), 20),
    }
    for measure, (linear_range, log_range, y_limit) in motion_settings.items():
        readable = measure.replace("_", " ")
        plot_histogram = getattr(cls, f"plot_multi_experiment_{measure}_histogram")
        histogram_root = out_root / f"{measure}_histogram"
        plot_histogram(
            experiments, title=f"{label} {readable} density", n_bins=50, hist_range=linear_range,
            save_fig_root=histogram_root, save_fig_prefix=f"{measure}_histogram_{label}",
            **grouping,
        )
        plot_histogram(
            experiments, title=f"{label} {readable} density", n_bins=25, hist_range=log_range,
            save_fig_root=histogram_root, save_fig_prefix=f"{measure}_histogram_log_{label}",
            log_x=True, **grouping,
        )
        cls.plot_multi_experiment_motion(
            experiments, y_limit=y_limit, title=f"{label} {readable}", measure=measure,
            save_fig_root=out_root / measure, save_fig_prefix=f"{measure}_{label}",
            **grouping,
        )

    cls.plot_multi_experiment_distance_from_point(
        experiments, title=f"{label} distance from cue",
        save_fig_root=out_root / "distance", save_fig_prefix=f"distance_{label}",
        **grouping,
    )

    # --- categorical plots (side of box, freezing) ---
    percent_kwargs = {"unit": cat_unit, "x_label": "", **grouping}
    side_kwargs = {
        "measure": "occupancy", "measure_options": occupancy_meas_args,
        "categoricals": cue_categoricals,
    }
    index_freezing_kwargs = {
        "measure": "motion_index_freezing", "measure_options": freezing_index_meas_args,
        "categoricals": freeze_categoricals,
    }
    speed_freezing_kwargs = {
        "measure": "speed_freezing", "measure_options": freezing_speed_meas_args,
        "categoricals": freeze_categoricals,
    }
    side_root = out_root / "side"
    index_freezing_root = out_root / "motion_index_freezing"
    speed_freezing_root = out_root / "speed_freezing"

    cls.plot_multi_experiment_categorical_percent(
        experiments, title="Side of box duration",
        save_fig_root=side_root / "side", save_fig_prefix=f"side_{label}",
        **side_kwargs, **percent_kwargs,
    )
    cls.plot_multi_experiment_categorical_percent(
        experiments,
        title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}",
        save_fig_root=index_freezing_root / "freezing", save_fig_prefix=f"freezing_{label}",
        **index_freezing_kwargs, **percent_kwargs,
    )
    cls.plot_multi_experiment_categorical_percent(
        experiments,
        title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}",
        save_fig_root=speed_freezing_root / "freezing", save_fig_prefix=f"freezing_{label}",
        **speed_freezing_kwargs, **percent_kwargs,
    )

    # --- bout plots: every plot type for every bout measure ---
    bout_measures = ("speed_freezing", "motion_index_freezing", "motion_index_dart", "speed_dart")
    for plot_type in ("scatter", "count", "mean", "duration"):
        for measure in bout_measures:
            cls.plot_multi_experiment_bouts(
                experiments, measure, frame_rate=frame_rate, plot_type=plot_type, unit=cat_unit,
                save_fig_root=out_root / "bouts" / plot_type,
                save_fig_prefix=f"{measure}_{label}",
                **grouping, **bouts_parameters[measure],
            )

    # the merged plots are only requested when filters were supplied
    if not filter_args:
        return

    by_group_kwargs = {
        "show_titles": False, "show_legend": True, "show_xlabel": False, "unit": cat_unit,
        **grouping,
    }

    cls.plot_multi_experiment_merged_by_period_categorical_percent(
        experiments, title="Side of box duration",
        save_fig_root=side_root / "side_merged_by_period",
        save_fig_prefix=f"side_merged_by_period_{label}",
        **side_kwargs, **percent_kwargs,
    )
    cls.plot_multi_experiment_merged_by_group_categorical_percent(
        experiments, title="Total time spent in cue side",
        x_label="Cue side", only_categorical="Cue-side",
        save_fig_root=side_root / "side_merged_by_group",
        save_fig_prefix=f"side_merged_by_group_{label}",
        **side_kwargs, **by_group_kwargs,
    )
    cls.plot_multi_experiment_merged_by_period_categorical_percent(
        experiments,
        title=f"Motion index freezing duration. Threshold={freezing_index_meas_args['threshold']}",
        save_fig_root=index_freezing_root / "freezing_merged_by_period",
        save_fig_prefix=f"freezing_merged_by_period_{label}",
        **index_freezing_kwargs, **percent_kwargs,
    )
    cls.plot_multi_experiment_merged_by_group_categorical_percent(
        experiments,
        title=f"Total time spent freezing (motion index). Threshold={freezing_index_meas_args['threshold']}",
        x_label="Freezing", only_categorical="Freezing",
        save_fig_root=index_freezing_root / "freezing_merged_by_group",
        save_fig_prefix=f"freezing_merged_by_group_{label}",
        **index_freezing_kwargs, **by_group_kwargs,
    )
    cls.plot_multi_experiment_merged_by_period_categorical_percent(
        experiments,
        title=f"Speed freezing duration. Threshold={freezing_speed_meas_args['threshold']}",
        save_fig_root=speed_freezing_root / "freezing_merged_by_period",
        save_fig_prefix=f"freezing_merged_by_period_{label}",
        **speed_freezing_kwargs, **percent_kwargs,
    )
    cls.plot_multi_experiment_merged_by_group_categorical_percent(
        experiments,
        title=f"Total time spent freezing (speed). Threshold={freezing_speed_meas_args['threshold']}",
        x_label="Freezing", only_categorical="Freezing",
        save_fig_root=speed_freezing_root / "freezing_merged_by_group",
        save_fig_prefix=f"freezing_merged_by_group_{label}",
        **speed_freezing_kwargs, **by_group_kwargs,
    )
@classmethod
def print_measure_hist(
        cls, experiments: list["Experiment"], percentile: float, downsample_factor: int = 1,
        hab_segment: tuple[float, float] = (), trial_segment: tuple[float, float] = (),
        post_segment: tuple[float, float] = (),
) -> None:
    """Print range counts and a percentile of the motion index and speed of ``experiments``.

    At most one of ``hab_segment``, ``trial_segment`` and ``post_segment`` may be given. When one is
    given, the percentile is computed on ``extract_subject_period(...)`` of every experiment;
    otherwise on every experiment's ``pos_data``.

    :param experiments: experiments to summarise.
    :param percentile: percentile forwarded to ``percentile_motion_range``.
    :param downsample_factor: forwarded to ``count_motion_range`` and ``percentile_motion_range``.
    :param hab_segment: optional segment, forwarded to ``count_motion_range`` and
        ``extract_subject_period``.
    :param trial_segment: optional segment, forwarded like ``hab_segment``.
    :param post_segment: optional segment, forwarded like ``hab_segment``.
    :raises ValueError: if more than one segment is specified.
    """
    segments = {
        "hab_segment": hab_segment, "trial_segment": trial_segment, "post_segment": post_segment,
    }
    selected = {name: segment for name, segment in segments.items() if segment}
    if len(selected) > 1:
        raise ValueError("Only one of hab_segment, trial_segment, post_segment may be specified")

    print(f"Downsampling measures at {downsample_factor} from original frame rate")
    for heading, measure in (("Motion index", "motion_index"), ("Speed", "speed")):
        print(f"{heading} counts between number to next larger number:")
        pprint.pprint(cls.count_motion_range(
            experiments, measure=measure, downsample_factor=downsample_factor, **segments,
        ))

    if selected:
        data = [exp.extract_subject_period(**segments) for exp in experiments]
        (segment_name, segment), = selected.items()
        # report the period that was really analysed instead of a fixed "0:30 - 4:30 pre-trial"
        period = f"within {segment_name}={segment}"
    else:
        data = [exp.pos_data for exp in experiments]
        period = "over all position data"

    motion_index_percentile = cls.percentile_motion_range(
        data, percentile, measure="motion_index", downsample_factor=downsample_factor,
    )
    speed_percentile = cls.percentile_motion_range(
        data, percentile, measure="speed", downsample_factor=downsample_factor,
    )
    print(f"Motion index at {percentile} percentile {period} is {motion_index_percentile}")
    print(f"Speed at {percentile} percentile {period} is {speed_percentile}")
@dataclass
class SubjectPos:
filename: str | Path
times: np.ndarray
track: np.ndarray = field(repr=False)
motion_index: np.ndarray = field(repr=False)
cm_per_px: float | None = None
_point_marker: dict = field(default_factory=dict, init=False, repr=False)
@property
def min_x(self):
return np.min(self.track[:, 0])
@property
def min_y(self):
return np.min(self.track[:, 1])
@property
def max_x(self):
return np.max(self.track[:, 0])
@property
def max_y(self):
return np.max(self.track[:, 1])
@property
def duration(self) -> float:
if not len(self.times):
return 0
return (self.times[-1] - self.times[0]).item()
@classmethod
def parse_csv_track(
cls, filename: Path, subject_name: str = "mouse", frame_rate: float = 30., cm_per_px: float | None = None,
) -> "SubjectPos":
track = []
times = []
index = []
print(f"Reading {filename}")
with open(filename, "r") as fh:
reader = csv.reader(fh)
header = list(next(reader))
last_t = None
last_frame_t = None
for i, row in enumerate(reader):
frame, instance, cx, cy, index_val, t1, t2 = row
if header[-1] == "real_timestamp_sec":
t = t2
elif header[-2] == "real_timestamp_sec":
t = t1
else:
raise ValueError("Cannot find the real_timestamp_sec column")
if instance != subject_name:
continue
track.append((float(cx), float(cy)))
index.append(float(index_val))
t = float(t)
frame_t = float(frame) / frame_rate
if not (frame_t * 0.95 <= t <= frame_t * 1.05):
raise ValueError(f"{filename} frame rate of {frame_rate} does not match time stamp at \"{row}\"")
if last_t is not None and last_t >= t or last_frame_t is not None and last_frame_t >= frame_t:
raise ValueError(f"{filename} sequential times are not increasing \"{row}\"")
last_t = t
last_frame_t = frame_t
times.append(frame_t)
times = np.array(times)
return SubjectPos(
filename=filename, times=np.array(times), track=np.array(track), motion_index=np.array(index),
cm_per_px=cm_per_px,
)
def extract_range(self, t_start: float | None = None, t_end: float | None = None) -> "SubjectPos":
if t_start is None:
t_start = self.times[0]
if t_end is None:
t_end = self.times[-1] + 1
i_s = np.sum(self.times < t_start)
i_e = np.sum(self.times <= t_end)
return SubjectPos(
filename=self.filename, times=self.times[i_s:i_e], track=self.track[i_s:i_e],
motion_index=self.motion_index[i_s:i_e], cm_per_px=self.cm_per_px,
)
@classmethod
def calculate_bouts(
cls, data: np.ndarray, threshold: float, above_threshold: bool, frame_rate: float,
refractory_period: float = 0, downsample_factor: int = 1, min_bout_frames: int = 0,
) -> list[tuple[float, float]]:
final_frame_rate = frame_rate / downsample_factor
refractory_n = int(round(refractory_period * final_frame_rate))
if refractory_period:
refractory_n = max(refractory_n, 1)
if len(data) <= 50:
return []
if downsample_factor != 1:
data = decimate(data, downsample_factor)
if above_threshold:
met_threshold = data >= threshold
else:
met_threshold = data <= threshold
last_marked = None
start_i = None
bouts = []
for i, state in enumerate(met_threshold):
if state:
last_marked = i
if start_i is None:
start_i = i
else:
continue
else:
if start_i is None:
continue
if not refractory_n:
bouts.append((start_i, i - start_i))
start_i = None
else:
if i - last_marked >= refractory_n:
bouts.append((start_i, last_marked - start_i + 1))
start_i = None
if start_i is not None:
bouts.append((start_i, last_marked - start_i + 1))
if min_bout_frames:
bouts = [b for b in bouts if b[1] >= min_bout_frames]
bouts = [(t / final_frame_rate, dur / final_frame_rate) for t, dur in bouts]
return bouts
def calculate_occupancy(
self, occupancy: np.ndarray, pos_to_index: Callable | None = None, frame_normalize: bool = True,
) -> None:
n = self.track.shape[0]
if not n:
return
grid_width, grid_height = occupancy.shape
frame_proportion = 1
if frame_normalize:
frame_proportion = 1 / n
for i in range(n):
if self.track[i, 0] < 0 or self.track[i, 1] < 0:
continue
if pos_to_index is None:
x, y = self.track[i, :]
else:
x, y = pos_to_index(*self.track[i, :], grid_width, grid_height)
x = int(min(x, grid_width))
y = int(min(y, grid_height))
occupancy[x, y] += frame_proportion
def calculate_speed(self, pad_to_input: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
track = self.track
valid = np.logical_and(track[:, 0] >= 0, track[:, 1] >= 0)
track = track[valid, :]
times = self.times[valid]
elapsed = times[1:] - times[:-1]
if np.any(elapsed <= 0):
raise ValueError(f"Found sequential time stamps whose difference is not positive {self}")
cm_per_px = self.cm_per_px or 1
dist = np.sqrt(np.sum(np.square(track[1:, :] - track[:-1, :]), axis=1, keepdims=False))
dist = dist * cm_per_px
speed = dist / elapsed
if pad_to_input:
return (
np.pad(times[1:], (1, 0), "edge"),
np.pad(speed, (1, 0), "edge"),
np.pad(dist, (1, 0), "edge"),
np.pad(elapsed, (1, 0), "edge")
)
return times[1:], speed, dist, elapsed
@classmethod
def plot_occupancy_data(
cls, occupancy: np.ndarray, fig: plt.Figure, ax: plt.Axes,
gaussian_sigma: float = 0, intensity_limit: float = 0, scale_to_one: bool = True,
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
color_bar: bool = True, cm_per_px: float = 1,
):
if gaussian_sigma:
occupancy = gaussian_filter(occupancy, gaussian_sigma)
if scale_to_one:
max_val = occupancy.max()
if max_val:
occupancy /= max_val
else:
occupancy[:] = 0
n_rows, n_cols = occupancy.T.shape
extent = (-0.5 * cm_per_px, (n_cols - 0.5) * cm_per_px, (n_rows - 0.5) * cm_per_px, -0.5 * cm_per_px)
im = ax.imshow(
occupancy.T, aspect="auto", origin="upper", cmap="viridis", interpolation="sinc",
interpolation_stage="data", vmax=intensity_limit or None, extent=extent,
)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if title:
ax.set_title(title)
if color_bar:
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
def plot_occupancy(
self, grid_width: int, grid_height: int,
pos_to_index: Callable | None = None, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
color_bar: bool = True,
):
occupancy = np.zeros((grid_width, grid_height))
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
self.calculate_occupancy(occupancy, pos_to_index, frame_normalize)
cm_per_px = self.cm_per_px or 1
self.plot_occupancy_data(
occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one, x_label, y_label, title, color_bar,
cm_per_px=cm_per_px,
)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_motion_index(
self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Motion index",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
cm_per_px = self.cm_per_px or 1
motion_index = self.motion_index * cm_per_px
valid = motion_index >= 0
times = self.times[valid]
if len(times):
ax.plot(times / 60 - self.times[0] / 60, motion_index[valid], **SubjectPos._point_marker)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_speed(
self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Speed",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
times, speed, _, _ = self.calculate_speed()
if len(times):
ax.plot(times / 60 - self.times[0] / 60, speed, **SubjectPos._point_marker)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def transform_to_categorical_pos(
self, position_to_side: Callable, categoricals: tuple[str, ...],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
track = self.track[:-1, :]
valid = np.logical_and(track[:, 0] >= 0, track[:, 1] >= 0)
pos_categoricals_name = [position_to_side(*p) for p in track[valid, :]]
categoricals = {name: i for i, name in enumerate(categoricals)}
categoricals_index = np.array([categoricals[n] for n in pos_categoricals_name])
time_diff = self.times[1:] - self.times[:-1]
assert np.all(time_diff >= 0)
time_diff = time_diff[valid]
times = self.times[:-1][valid]
return times, time_diff, categoricals_index
def plot_categorical_values(
self, times, categoricals_index, categoricals: tuple[str, ...],
fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
ax.plot(times / 60 - self.times[0] / 60, categoricals_index, "*", alpha=.3, ms=5)
ax.set_yticks(np.arange(len(categoricals)), categoricals)
ax.set_xlabel("Time (min)")
ax.set_ylabel("Side of box relative to teaball")
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_distance_from_point(
self, point_xy: tuple[int, int], fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Distance",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0)
point = np.array(point_xy)[None, :]
distance = np.sqrt(np.sum(np.square(self.track[valid, :] - point), axis=1))
cm_per_px = self.cm_per_px or 1
distance = distance * cm_per_px
times = self.times[valid]
if len(times):
ax.plot(times / 60 - self.times[0] / 60, distance, **SubjectPos._point_marker)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def calculate_measure_bout(
self, measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float,
refractory_period: float = 0, downsample_factor: int = 1, min_bout_frames: int = 0,
) -> list[tuple[float, float]]:
match measure:
case "motion_index_freezing":
cm_per_px = self.cm_per_px or 1
index = self.motion_index * cm_per_px
above = False
case "motion_index_dart":
cm_per_px = self.cm_per_px or 1
index = self.motion_index * cm_per_px
above = True
case "speed_freezing":
times, index, _, time_diff = self.calculate_speed()
above = False
case "speed_dart":
times, index, _, time_diff = self.calculate_speed()
above = True
case _:
assert False
bouts = self.calculate_bouts(
index, threshold, above_threshold=above, frame_rate=frame_rate,
refractory_period=refractory_period,
downsample_factor=downsample_factor, min_bout_frames=min_bout_frames,
)
return bouts
def plot_bout(
self, measure: BOUT_MEASURE_TYPE, threshold: float, frame_rate: float, refractory_period: float = 0,
downsample_factor: int = 1, min_bout_frames: int = 0,
plot_type: Literal["scatter", "count", "mean", "duration"] = "scatter",
fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Duration (s)", norm_time: bool = False,
) -> list[tuple[float, float]]:
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
bouts = self.calculate_measure_bout(
measure, threshold, frame_rate, refractory_period, downsample_factor, min_bout_frames
)
match plot_type:
case "scatter":
if len(bouts):
ax.plot([b[0] / 60 for b in bouts], [b[1] for b in bouts], **SubjectPos._point_marker)
if len(self.times):
ax.set_xlim(0, (self.times[-1].item() - self.times[0].item()) / 60)
case "count":
ax.bar([""], [len(bouts) / (self.duration if norm_time else 1)])
case "mean":
val = sum(b[1] for b in bouts) / len(bouts) if bouts else 0
ax.bar([0], [val])
ax.plot([0] * len(bouts), [b[1] for b in bouts], "k.")
ax.set_xticks([0], [""])
case "duration":
norm_factor = (100 / self.duration) if norm_time else 1
ax.bar([0], [sum(b[1] for b in bouts) * norm_factor])
ax.plot([0] * len(bouts), np.cumsum([b[1] for b in bouts]) * norm_factor, "k.")
ax.set_xticks([0], [""])
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
return bouts
class ExperimentParameters:
    """Interface describing one experiment set: where its files live and how to analyse it.

    The class only declares the expected attributes and methods; every method raises
    ``NotImplementedError``, so a concrete parameter set has to subclass it and provide them.
    """

    # base directory of the experiment set; the result directories below are derived from it
    root: Path
    figure_root: Path
    data_root: Path
    video_root: Path
    image_root: Path
    # metadata field names; they also appear as placeholders in the format strings below
    metadata: list[str]
    # format string of the tracking file name of an experiment
    filename_fmt: str
    # labels of the three periods of an experiment
    experiment_triplet_name: tuple[str, str, str]
    # keyword arguments of the bout detection, keyed by bout measure name
    bouts_parameters: dict[str, dict[str, float]]
    # threshold per measure name
    thresholds: dict[str, float]
    frame_rate: float
    # keyword arguments expanded into ``Experiment.read_pos_track``
    times_of_interest: dict[str, float]
    # passed as ``image_size`` to ``Experiment.set_box_metadata``
    image_size: tuple[int, int]
    output_fig_h: int
    title_root_fmt: str
    subject_title_fmt: str
    video_filename_fmt: str
    # (start, end) segments passed as ``hab_segments`` to ``Experiment.chop_pos_track``
    chopped_pre_times: list[tuple[int, int]]
    occupancy_meas_args: dict
    freezing_index_meas_args: dict
    freezing_speed_meas_args: dict
    # first element of cue, should be the label for the side of the cue
    cue_categoricals: tuple[str, str]
    freeze_categoricals: tuple[str, str]
    # allowed values per metadata field used for grouping
    experiment_groups: dict[str, tuple[str, ...]]
    individual_plot_groups: tuple[str, ...]
    # each entry names the fields under "individual_plot_per" and "groups_in_single_plot"
    multi_plot_groups: list[dict[str, tuple[str, ...]]]

    def get_inventory_csv(self, batch: int | None = None) -> Path:
        """Return the path of the inventory CSV of ``batch``; must be overridden."""
        raise NotImplementedError

    def get_tracked_data_root(self, batch: int | None = None) -> Path:
        """Return the directory holding the tracking files of ``batch``; must be overridden."""
        raise NotImplementedError

    def read_all_experiments(self) -> list[Experiment]:
        """Load and return all experiments of this parameter set; must be overridden."""
        raise NotImplementedError

    def get_cm_per_px(self, **selectors) -> float:
        """Return the cm-per-pixel scale for the experiment described by ``selectors``.

        Subclasses in this file call it with an experiment's metadata as ``selectors``.
        Must be overridden.
        """
        raise NotImplementedError

    def get_cue_point(self, **selectors) -> tuple[float | str, float | str]:
        """Return the cue location for the experiment described by ``selectors``.

        Each coordinate is either a number or a string. Must be overridden.
        """
        raise NotImplementedError

    def get_box_size_cc(self, **selectors) -> tuple[float, float, float, float]:
        """Return four numbers describing the box for the experiment described by ``selectors``.

        The values are expanded positionally into ``Experiment.set_box_metadata`` by the
        subclasses in this file. Must be overridden.
        """
        raise NotImplementedError
class YidanParameters(ExperimentParameters):
    """Parameters of the experiments stored below ``H:\\yidan_2025_behavior``.

    The experiments are split into numbered batches; each batch has an inventory CSV in ``root`` and
    a ``tracking/batch<N>`` directory with the tracking files.
    """

    root = Path(r'H:\yidan_2025_behavior')
    figure_root = root / "results" / "figures"
    data_root = root / "results" / "data"
    video_root = root / "results" / "video"
    image_root = root / "results" / "video_images"

    metadata = ["date", "subject", "strain", "condition", "cell_label", "exposure", "sex"]
    filename_fmt = "{date}_{subject}_top_tracked.csv"
    experiment_triplet_name = ("Pre Trial", "Trial", "Post Trial")

    # the trailing comments record the percentile noted for each threshold
    bouts_parameters = {
        "motion_index_freezing": dict(
            threshold=0.001278,  # %2.5
            refractory_period=0.25, downsample_factor=1, min_bout_frames=2,
        ),
        "speed_freezing": dict(
            threshold=0.1180,  # %5
            refractory_period=0.25, downsample_factor=1, min_bout_frames=2,
        ),
        "motion_index_dart": dict(
            threshold=0.1242,  # %90
            refractory_period=0.8, downsample_factor=6, min_bout_frames=2,
        ),
        "speed_dart": dict(
            threshold=12.85,  # %95
            refractory_period=0.8, downsample_factor=6, min_bout_frames=2,
        ),
    }
    thresholds = dict(
        speed_freezing=0.1180,  # %5
        motion_index_freezing=0.001278,  # %2.5
        speed_dart=14.32,  # %95
        motion_index_dart=0.1391,  # %90
    )

    frame_rate = 24
    times_of_interest = dict(
        pre_offset=3 * 60,
        pre_duration=2 * 60,
        trial_offset=0,
        trial_duration=2 * 60,
        post_offset=0,
        post_duration=2 * 60,
    )

    image_size = (960, 600)
    # half of the image height, rounded down to an even number
    output_fig_h = image_size[1] // 2 // 2 * 2

    title_root_fmt = "{subject}-{exposure}-{date}"
    subject_title_fmt = "{subject}_exposure-{exposure}_{date}"
    video_filename_fmt = "{date}_{subject}_top.mp4"

    # ten consecutive segments of 30 each
    chopped_pre_times = [(start, start + 30) for start in range(0, 10 * 30, 30)]

    occupancy_meas_args = {"split_horizontally": True}
    freezing_index_meas_args = {"threshold": thresholds["motion_index_freezing"], "freeze_name": "Freezing"}
    freezing_speed_meas_args = {"threshold": thresholds["speed_freezing"], "freeze_name": "Freezing"}
    cue_categoricals = ("Cue-side", "Far-side")
    freeze_categoricals = ("Freezing", "Non-freezing")

    experiment_groups = {
        "condition": ("Blank", "TMT", "2MBA", "IAMM"),
        "strain": ("TRAP1",),
        "exposure": ("1", "2"),
        "sex": ("M", "F"),
    }
    individual_plot_groups = ("condition", "strain", "exposure")
    multi_plot_groups = [
        dict(individual_plot_per=("exposure", "strain"), groups_in_single_plot=("condition",)),
        dict(individual_plot_per=("condition", "strain"), groups_in_single_plot=("exposure",)),
        dict(individual_plot_per=("condition", "strain"), groups_in_single_plot=("exposure", "sex")),
    ]

    def get_inventory_csv(self, batch: int | None = None) -> Path:
        """Return the inventory CSV of ``batch``; a batch number is required."""
        if batch is not None:
            return self.root / f"behavioral data timestamp - b{batch} timestamps.csv"
        raise NotImplementedError

    def get_tracked_data_root(self, batch: int | None = None) -> Path:
        """Return the tracking directory of ``batch``; a batch number is required."""
        if batch is not None:
            return self.root / "tracking" / f"batch{batch}"
        raise NotImplementedError

    def read_all_experiments(self) -> list[Experiment]:
        """Read the experiments of batches 3, 5, 6 and 7 and load their tracks.

        Batch 3 takes its box metadata from the ``json`` directory of the batch; the other batches
        use the fixed values of ``get_box_size_cc``. All experiments are finally passed together to
        ``Experiment.enlarge_canvas``.
        """
        loaded = []
        for batch in (3, 5, 6, 7):
            tracking_root = self.get_tracked_data_root(batch)
            batch_experiments = Experiment.read_experiments_from_csv_inventory(
                self.get_inventory_csv(batch), self.metadata, self.filename_fmt, self.title_root_fmt
            )
            loaded.extend(batch_experiments)

            for experiment in batch_experiments:
                selectors = experiment.metadata
                if batch == 3:
                    experiment.parse_box_metadata(
                        tracking_root / "json", cm_per_px=self.get_cm_per_px(**selectors),
                    )
                else:
                    experiment.set_box_metadata(
                        *self.get_box_size_cc(**selectors), image_size=self.image_size,
                        cm_per_px=self.get_cm_per_px(**selectors),
                    )
                experiment.read_pos_track(
                    tracking_root, **self.times_of_interest, frame_rate=self.frame_rate,
                )
                experiment.chop_pos_track(hab_segments=self.chopped_pre_times)
                experiment.cue_point = self.get_cue_point(**experiment.metadata)

        # all experiments share the same scaling, even if slightly different sizes
        Experiment.enlarge_canvas(loaded)
        return loaded

    def get_cm_per_px(self, **selectors) -> float:
        """Return the scale, which is the same for every experiment."""
        return 20 / 855.1 * 2

    def get_cue_point(self, **selectors) -> tuple[float | str, float | str]:
        """Return the cue location, which is the same for every experiment."""
        return "left", "bottom"

    def get_box_size_cc(self, **selectors) -> tuple[float, float, float, float]:
        """Return the box values, which are the same for every experiment."""
        return 200, 132, 784, 486
class CynthiaParameters(ExperimentParameters):
root = Path(r'H:\cynthia')
figure_root = root / "results" / "figures"
data_root = root / "results" / "data"
video_root = root / "results" / "video"
image_root = root / "results" / "video_images"
metadata = ["date", "subject", "stage", "condition", "cue", "sex"]
filename_fmt = "{subject}_{stage}_{date}_tracked.csv"
experiment_triplet_name = "Pre Trial", "Trial", "Post Trial"
bouts_parameters = {
"motion_index_freezing": {
"threshold": 0.001278, # %2.5
"refractory_period": 0.25,
"downsample_factor": 1,
"min_bout_frames": 2,
},
"speed_freezing": {
"threshold": 0.1180, # %5
"refractory_period": 0.25,
"downsample_factor": 1,
"min_bout_frames": 2,
},
"motion_index_dart": {
"threshold": 0.1242, # %90
"refractory_period": 0.8,
"downsample_factor": 6,
"min_bout_frames": 2,
},
"speed_dart": {
"threshold": 12.85, # %95
"refractory_period": 0.8,
"downsample_factor": 6,
"min_bout_frames": 2,
}
}
thresholds = {
"speed_freezing": 0.1180, # %5
"motion_index_freezing": 0.001278, # %2.5
"speed_dart": 14.32, # %95
"motion_index_dart": 0.1391, # %90
}
frame_rate = 24
times_of_interest = {
"pre_duration": -28,
"post_duration": 28
}
image_size = 640, 480
output_fig_h = image_size[1] // 2
output_fig_h = (output_fig_h // 2) * 2
title_root_fmt = "{subject}-{stage}-{condition}-{cue}"
subject_title_fmt = "{subject}_{stage}_{condition}_{cue}"
video_filename_fmt = "{subject}_{stage}_{date}.mp4"
chopped_pre_times = [(i * 30, (i + 1) * 30) for i in range(8)]
occupancy_meas_args = {"split_horizontally": False}
freezing_index_meas_args = {"threshold": thresholds["motion_index_freezing"], "freeze_name": "Freezing"}
freezing_speed_meas_args = {"threshold": thresholds["speed_freezing"], "freeze_name": "Freezing"}
cue_categoricals = "Cue-side", "Far-side"
freeze_categoricals = "Freezing", "Non-freezing"
experiment_groups = {
"condition": ("cued", "control", "back"),
"cue": ("odor",),
"stage": ("test", "test_24hr"),
}
individual_plot_groups = "condition", "cue", "stage"
multi_plot_groups = [
{
"individual_plot_per": ("condition", "cue"),
"groups_in_single_plot": ("stage", ),
},
{
"individual_plot_per": ("stage", "cue"),
"groups_in_single_plot": ("condition", ),
},
{
"individual_plot_per": ("cue", ),
"groups_in_single_plot": ("condition", "stage"),
},
]
def __init__(self, batch_names: list[str], cue_in_corner: bool, test_is_train_chamber: bool):
self.cue_in_corner = cue_in_corner
self.test_is_train_chamber = test_is_train_chamber
self.batch_names = batch_names
def get_inventory_csv(self, batch: int | None = None) -> Path:
if batch is None:
return self.root / f"{self.batch_names[0]}.csv"
return self.root / f"{self.batch_names[batch]}.csv"
def get_tracked_data_root(self, batch: int | None = None) -> Path:
if batch is None:
return self.root / self.batch_names[0]
return self.root / self.batch_names[batch]
def read_all_experiments(self) -> list[Experiment]:
    """Load every batch's experiments, attach their tracking data and align their canvases.

    For each batch the inventory CSV is parsed into experiments. Every experiment
    then receives its box metadata, its position track, its chopped habituation
    segments and its cue point. Finally ``Experiment.enlarge_canvas`` is applied
    to each set of experiments that share the same cm-per-pixel scale.

    :return: the experiments of all batches, in batch order.
    """
    all_experiments = []
    for num in range(len(self.batch_names)):
        experiments = Experiment.read_experiments_from_csv_inventory(
            self.get_inventory_csv(num), self.metadata, self.filename_fmt, self.title_root_fmt
        )
        all_experiments.extend(experiments)

        for experiment in experiments:
            experiment.set_box_metadata(
                *self.get_box_size_cc(**experiment.metadata), image_size=self.image_size,
                cm_per_px=self.get_cm_per_px(**experiment.metadata),
            )
            experiment.read_pos_track(
                self.get_tracked_data_root(num), **self.times_of_interest, frame_rate=self.frame_rate,
            )
            experiment.chop_pos_track(hab_segments=self.chopped_pre_times)
            experiment.cue_point = self.get_cue_point(**experiment.metadata)

    if self.test_is_train_chamber:
        Experiment.enlarge_canvas(all_experiments)
    else:
        # Align experiments that share the same scaling, even if slightly different sizes.
        # Group by the scale itself rather than by a fixed list of stage names: stages
        # other than "train"/"test" (e.g. "test_24hr") get a scale from get_cm_per_px as
        # well and must be aligned with the other recordings at that scale, not skipped.
        by_scale: dict[float, list[Experiment]] = {}
        for experiment in all_experiments:
            by_scale.setdefault(self.get_cm_per_px(**experiment.metadata), []).append(experiment)
        for same_scale_experiments in by_scale.values():
            Experiment.enlarge_canvas(same_scale_experiments)

    return all_experiments
def get_cm_per_px(self, **selectors) -> float:
    """Return the centimetres-per-pixel scale for the experiment described by ``selectors``.

    The training-chamber scale is used when testing happens in the training chamber
    or when ``selectors["stage"]`` is ``"train"``; any other stage gets the test-chamber
    scale. ``selectors["stage"]`` is only read when ``test_is_train_chamber`` is false.
    """
    # presumably a 15 cm reference length measured in pixels in each chamber - confirm
    if self.test_is_train_chamber or selectors["stage"] == "train":
        return 15 / 258.86
    return 15 / 200.01
def get_cue_point(self, **selectors) -> tuple[float | str, float | str]:
if self.test_is_train_chamber:
if self.cue_in_corner:
return 207, 52
return 345, 80
if self.cue_in_corner:
if selectors["stage"] == "train":
return 207, 52
return 446, 429
if selectors["stage"] == "train":
return 345, 80
return 330, 430
def get_box_size_cc(self, **selectors) -> tuple[float, float, float, float]:
    """Return the four box values for the experiment described by ``selectors``.

    The result is unpacked as the leading positional arguments of
    ``Experiment.set_box_metadata`` (presumably two corner coordinates in pixels -
    confirm against that method). The training-chamber box is used when testing
    happens in the training chamber or when ``selectors["stage"]`` is ``"train"``.
    """
    # box sizes must be similar if averaging occupancy across them
    if self.test_is_train_chamber or selectors["stage"] == "train":
        return 74, 32, 593, 480
    return 133, 0, 550, 466
def run_exports(
        parameters: ExperimentParameters, export_csv: bool = True, export_subjects: bool = True,
        export_groups: bool = True, export_multi_groups: bool = True,
        cat_unit: Literal["times", "percent"] = "times",
):
    """Read all experiments described by ``parameters`` and run the selected exports.

    Side effects: overrides ``SubjectPos._point_marker`` and ``Experiment.triplet_name``
    at class level before anything is read.

    :param parameters: experiment configuration; also provides the output roots.
    :param export_csv: write the per-experiment CSV export under ``parameters.data_root``.
    :param export_subjects: write one set of figures per subject.
    :param export_groups: write one set of figures for every combination of the
        ``parameters.individual_plot_groups`` values.
    :param export_multi_groups: write the figures described by ``parameters.multi_plot_groups``,
        where several groups are combined in a single figure.
    :param cat_unit: unit forwarded to the exporters for categorical measures.
    """
    SubjectPos._point_marker = {"marker": ".", "markersize": 1.5, "linestyle": ""}
    Experiment.triplet_name = parameters.experiment_triplet_name

    experiments = parameters.read_all_experiments()

    if export_csv:
        Experiment.export_experiments_csv(
            experiments, parameters.data_root, parameters.occupancy_meas_args, parameters.cue_categoricals,
            parameters.freezing_index_meas_args, parameters.freezing_speed_meas_args, parameters.freeze_categoricals,
            cat_unit, True, parameters.frame_rate, parameters.bouts_parameters,
        )

    if export_subjects:
        Experiment.export_single_subject_figures(
            experiments, parameters.figure_root, parameters.occupancy_meas_args, cat_unit,
            parameters.subject_title_fmt, parameters.frame_rate, parameters.bouts_parameters,
        )

    individual_group_names = parameters.individual_plot_groups
    if export_groups:
        individual_group_items = [parameters.experiment_groups[g] for g in individual_group_names]
        # one export per combination of group values, e.g. (condition, cue, stage)
        for group_values in product(*individual_group_items):
            named_values = dict(zip(individual_group_names, group_values))
            experiments_ = Experiment.filter(experiments, **named_values)
            # str() each value so non-string group values build a label the same way
            # as in the multi-group export below, instead of raising TypeError
            label = "-".join(map(str, group_values))

            Experiment.export_multi_experiment_figures(
                experiments_, parameters.figure_root, parameters.occupancy_meas_args,
                cat_unit, label, parameters.cue_categoricals, parameters.freezing_index_meas_args,
                parameters.freezing_speed_meas_args, parameters.freeze_categoricals, parameters.frame_rate,
                parameters.bouts_parameters, None, "", "grouped",
            )

    if export_multi_groups:
        # collect (figure filter, in-figure filters, label, sub-directory) for every figure first
        items = []
        for split in parameters.multi_plot_groups:
            individual_plot_names = split["individual_plot_per"]
            groups_in_single_plot_names = split["groups_in_single_plot"]
            individual_plot_values = [parameters.experiment_groups[n] for n in individual_plot_names]
            groups_in_single_plot_values = [parameters.experiment_groups[n] for n in groups_in_single_plot_names]

            # every combination of the groups that share one figure
            filter_args = [
                dict(zip(groups_in_single_plot_names, group_values))
                for group_values in product(*groups_in_single_plot_values)
            ]

            # every combination that gets its own figure
            for group_values in product(*individual_plot_values):
                items.append((
                    dict(zip(individual_plot_names, group_values)),
                    filter_args,
                    "-".join(map(str, group_values)),
                    "-".join(groups_in_single_plot_names),
                ))

        for groups, filter_args, label, sub_dir in items:
            experiments_ = Experiment.filter(experiments, **groups)
            # "{name}" placeholders wrapped in an extra brace pair, each set in bold mathtext on its own line
            bracket_group_names = ["{{{" + f + "}}}" for f in filter_args[0].keys()]
            group_label = "$\\bf" + "$\n$\\bf".join(bracket_group_names) + "$\n\n"

            Experiment.export_multi_experiment_figures(
                experiments_, parameters.figure_root, parameters.occupancy_meas_args,
                cat_unit, label, parameters.cue_categoricals, parameters.freezing_index_meas_args,
                parameters.freezing_speed_meas_args, parameters.freeze_categoricals, parameters.frame_rate,
                parameters.bouts_parameters, filter_args, group_label, f"grouped_single_figure/by_{sub_dir}"
            )
def generate_video(
        parameters: ExperimentParameters, batch: int | None, time_window_padding: float, ffmpeg: str | Path,
        temp_images: bool, above_measure: Literal["motion_index", "speed"] | None,
        below_measure: Literal["motion_index", "speed"] | None,
        **selectors: str,
):
    """Load a single experiment from one batch and merge its video with generated images.

    :param parameters: experiment configuration (paths, formats, frame rate, thresholds, ...).
    :param batch: index of the batch to read from; ``None`` selects the first batch.
    :param time_window_padding: forwarded to ``Experiment.merge_video_with_images``.
    :param ffmpeg: ffmpeg executable, forwarded to ``Experiment.merge_video_with_images``.
    :param temp_images: forwarded to ``Experiment.merge_video_with_images``.
    :param above_measure: measure forwarded as the "above" trace, or ``None``.
    :param below_measure: measure forwarded as the "below" trace, or ``None``.
    :param selectors: forwarded to ``Experiment.read_experiment_from_csv_inventory``
        (presumably metadata values that identify the experiment, e.g. ``subject`` / ``date``).
    """
    Experiment.triplet_name = parameters.experiment_triplet_name

    inventory_csv = parameters.get_inventory_csv(batch)
    data_root = parameters.get_tracked_data_root(batch)

    experiment = Experiment.read_experiment_from_csv_inventory(
        inventory_csv, parameters.metadata, parameters.filename_fmt,
        parameters.title_root_fmt, **selectors,
    )

    # same preparation as for the bulk exports: box geometry first, then the position track
    box = parameters.get_box_size_cc(**experiment.metadata)
    scale = parameters.get_cm_per_px(**experiment.metadata)
    experiment.set_box_metadata(*box, image_size=parameters.image_size, cm_per_px=scale)
    experiment.read_pos_track(
        data_root, **parameters.times_of_interest, frame_rate=parameters.frame_rate,
    )

    # the source video sits next to the tracked data, named from the experiment's metadata
    video_name = parameters.video_filename_fmt.format(**experiment.metadata)
    video_filename = data_root / video_name

    experiment.merge_video_with_images(
        parameters, parameters.frame_rate, parameters.image_size, parameters.output_fig_h,
        time_window_padding, video_filename, ffmpeg, temp_images, above_measure, below_measure,
        parameters.bouts_parameters, parameters.thresholds,
    )
if __name__ == "__main__":
    # Configuration used for this run; swap in the alternative below for the other dataset.
    sample_parameters = CynthiaParameters(
        batch_names=["111025"],
        cue_in_corner=True,
        test_is_train_chamber=False,
    )
    # sample_parameters = YidanParameters()

    # Alternative: print the measure histograms.
    # Experiment.print_measure_hist(
    #     sample_parameters.read_all_experiments(), percentile=90, downsample_factor=1, hab_segment=(30, 270)
    # )

    # Only the per-subject figures are exported by default.
    run_exports(
        sample_parameters,
        export_csv=False,
        export_subjects=True,
        export_groups=False,
        export_multi_groups=False,
        # cat_unit="percent",
    )

    # Alternative: render annotated videos for individual subjects.
    # generate_video(
    #     sample_parameters, 5, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20241206", subject="1202",
    # )
    # generate_video(
    #     sample_parameters, 6, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20250423", subject="1282",
    # )
    # generate_video(
    #     sample_parameters, 6, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20250321", subject="1283",
    # )
    # generate_video(
    #     sample_parameters, 0, 3,
    #     r"H:\yidan_2025_behavior\ffmpeg\bin\ffmpeg.exe",
    #     False, "motion_index", "speed", date="20251024", subject="2831",
    # )

    # Alternative: measure histograms for a single tracked CSV.
    # pos_data = SubjectPos.parse_csv_track(
    #     Path(r"D:\code_data\yidan_2025\20251006_fake_mouse_top_tracked.csv"), frame_rate=sample_parameters.frame_rate,
    #     subject_name="mouse", cm_per_px=sample_parameters.cm_per_px,
    # )
    # exp = Experiment(0, 0, 0, 0, 0, 0, "", "", cm_per_px=sample_parameters.cm_per_px)
    # exp.pos_data = pos_data
    # Experiment.print_measure_hist(
    #     [exp], percentile=90, downsample_factor=1, hab_segment=(30, 270)
    # )

    # Alternative: compare the measure range of a real and a fake mouse recording.
    # Experiment.plot_compare_motion_index_range(
    #     Path(r"D:\code_data\cynthia_sp_2025\1303_test_20250506_tracked.csv"),
    #     Path(r"D:\code_data\cynthia_sp_2025\fakemouse_test_20250604_tracked.csv"),
    #     "Real mouse",
    #     "Fake mouse",
    #     15,
    #     measure="speed"
    # )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment