Skip to content

Instantly share code, notes, and snippets.

@avivajpeyi
Last active August 25, 2025 06:24
Show Gist options
  • Select an option

  • Save avivajpeyi/a4df641a4dc9752e3dee2be0fc6774a9 to your computer and use it in GitHub Desktop.

Select an option

Save avivajpeyi/a4df641a4dc9752e3dee2be0fc6774a9 to your computer and use it in GitHub Desktop.
psd_comparison.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNi8GHsKTBA19rbfOLzTj+2",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avivajpeyi/a4df641a4dc9752e3dee2be0fc6774a9/psd_comparison.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# PSD comparisons\n",
"\n",
"https://github.com/NirGutt/tPowerBilby/blob/main/tPowerBilby/Utils/PostProcessingtPowerBilby.py#L202C1-L215C25\n",
"\n",
"https://github.com/NirGutt/tPowerBilby/blob/268548c670939314b7fab81a316aeb145811a5b9/tPowerBilby/Utils/asd_data_manipulations.py#L241"
],
"metadata": {
"id": "vKv1BNLmGG75"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bdX46NZzoEI4",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ec5a55aa-7ad5-4eed-d57b-1cf07c7b9cac"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m30.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m57.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m31.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m315.5/315.5 kB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.2/108.2 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.8/43.8 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m131.1/131.1 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m47.4/47.4 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m113.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"pyopenssl 24.2.1 requires cryptography<44,>=41.0.5, but you have cryptography 45.0.6 which is incompatible.\n",
"pydrive2 1.21.3 requires cryptography<44, but you have cryptography 45.0.6 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
],
"source": [
"%pip install -q bilby gwpy gwosc h5py"
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import bilby\n",
"from gwosc.datasets import event_gps\n",
"from gwpy.timeseries import TimeSeries\n",
"import h5py\n",
"from scipy.stats import anderson, kstest, norm\n",
"import os\n",
"import h5py\n",
"import urllib.request\n",
"\n",
"\n",
"\n",
"def get_GW_data_asd_welch(psd_end_time, det='L1', duration=4, f_i=20, f_f=896, method='welch', n_step_back=32, roll_off=0.4):\n",
"    \"\"\"Estimate the ASD of `det` from the open data that ends at `psd_end_time`.\n",
"\n",
"    ``n_step_back * duration`` seconds of strain are fetched and handed to\n",
"    gwpy's ``TimeSeries.psd`` with ``duration``-second FFTs, a Tukey window and\n",
"    the requested averaging `method`. Returns ``(frequencies, asd)`` limited to\n",
"    ``f_i <= f <= f_f``, where asd is the square root of the PSD values.\n",
"    \"\"\"\n",
"    span = n_step_back * duration\n",
"    strain = TimeSeries.fetch_open_data(det, psd_end_time - span, psd_end_time)\n",
"\n",
"    # Tukey shape parameter derived from the roll-off and the FFT length\n",
"    tukey_alpha = 2 * roll_off / duration\n",
"    spectrum = strain.psd(fftlength=duration, window=(\"tukey\", tukey_alpha), method=method)\n",
"\n",
"    f = spectrum.frequencies.value\n",
"    asd = np.sqrt(spectrum.value)\n",
"    in_band = (f >= f_i) & (f <= f_f)\n",
"    return f[in_band], asd[in_band]\n",
"\n",
"\n",
"def get_GW_data(psd_end_time, psd_duration=4, det='H1', f_i=20, f_f=896, roll_off=0.2, return_raw_data=False):\n",
"    \"\"\"Frequency-domain strain of `det` for the `psd_duration` seconds ending at `psd_end_time`.\n",
"\n",
"    The open-data segment is loaded into an empty bilby interferometer and the\n",
"    resulting frequency-domain strain is divided by the square root of the\n",
"    interferometer's window factor. Returns ``(frequencies, strain)`` limited\n",
"    to ``f_i <= f <= f_f``; the strain is complex when `return_raw_data` is\n",
"    true and its magnitude otherwise.\n",
"    \"\"\"\n",
"    segment = TimeSeries.fetch_open_data(det, psd_end_time - psd_duration, psd_end_time)\n",
"\n",
"    ifo = bilby.gw.detector.get_empty_interferometer(det)\n",
"    ifo.strain_data.roll_off = roll_off\n",
"    ifo.maximum_frequency = f_f\n",
"    ifo.minimum_frequency = f_i\n",
"    ifo.strain_data.set_from_gwpy_timeseries(segment)\n",
"\n",
"    freqs = ifo.strain_data.frequency_array\n",
"    fd_strain = ifo.strain_data.frequency_domain_strain\n",
"    # Read the window factor only after the frequency-domain strain, in the\n",
"    # same order as before (presumably bilby sets it while windowing - TODO confirm)\n",
"    scale = np.sqrt(ifo.strain_data.window_factor)\n",
"\n",
"    band = (freqs >= f_i) & (freqs <= f_f)\n",
"    selected = fd_strain[band] if return_raw_data else np.abs(fd_strain[band])\n",
"    return freqs[band], selected / scale\n",
"\n",
"\n",
"\n",
"def get_GWTC_psd(url,event_name, det='H1', f_i=20, f_f=896):\n",
"    \"\"\"Download (once) a GWTC PE data release and return its ASD for every detector.\n",
"\n",
"    Args:\n",
"        url: Location of the PEDataRelease HDF5 file.\n",
"        event_name: Used to build the local cache name ``GWTC_{event_name}.h5``.\n",
"        det: Unused; kept so existing calls keep working. Every detector in\n",
"            the file is returned.\n",
"        f_i: Lowest frequency to keep (inclusive).\n",
"        f_f: Highest frequency to keep (inclusive).\n",
"\n",
"    Returns:\n",
"        dict mapping detector name to ``(frequencies, asd)``, where asd is the\n",
"        square root of the second column of the stored PSD table.\n",
"\n",
"    Raises:\n",
"        KeyError: if the file has no analysis group containing 'psds'.\n",
"    \"\"\"\n",
"    local_path=f'GWTC_{event_name}.h5'\n",
"    if not os.path.exists(local_path):\n",
"        print(f\"Downloading {url} to {local_path} ...\")\n",
"        # Download under a temporary name first so that an interrupted transfer\n",
"        # is never mistaken for a complete cached file on the next run.\n",
"        partial_path = local_path + '.part'\n",
"        try:\n",
"            urllib.request.urlretrieve(url, partial_path)\n",
"            os.replace(partial_path, local_path)\n",
"        finally:\n",
"            if os.path.exists(partial_path):\n",
"                os.remove(partial_path)\n",
"        print(\"Download complete.\")\n",
"\n",
"    preferred_group = 'C01:IMRPhenomXPHM'\n",
"    ret_dict = {}\n",
"    with h5py.File(local_path, 'r') as f:\n",
"        if preferred_group in f and 'psds' in f[preferred_group]:\n",
"            group_name = preferred_group\n",
"        else:\n",
"            # Same fallback as find_psd_group: first group that carries PSDs\n",
"            group_name = next(\n",
"                (k for k in f.keys() if isinstance(f[k], h5py.Group) and 'psds' in f[k]),\n",
"                None,\n",
"            )\n",
"            if group_name is None:\n",
"                raise KeyError(f\"No analysis group with 'psds' found in {local_path}\")\n",
"\n",
"        psd_data = f[group_name]['psds']\n",
"        for ifo_name in list(psd_data):\n",
"            table = psd_data[ifo_name][()]  # read the dataset from disk once\n",
"            freqs, vals = table[:, 0], table[:, 1]\n",
"            I = (freqs >= f_i) & (freqs <= f_f)\n",
"            ret_dict[ifo_name] = (freqs[I], np.sqrt(vals[I]))\n",
"\n",
"    return ret_dict\n",
"\n",
"\n",
"def get_pval(x_data, y_data, x_psd, y_psd):\n",
"    \"\"\"KS-test p-value for the noise-curve-normalised data being standard normal.\n",
"\n",
"    Args:\n",
"        x_data: Frequencies of the data.\n",
"        y_data: Complex frequency-domain data at `x_data`.\n",
"        x_psd: Frequencies of the noise curve.\n",
"        y_psd: Noise-curve values at `x_psd` that the data are divided by.\n",
"\n",
"    Returns:\n",
"        p-value of ``scipy.stats.kstest`` against ``norm.cdf`` for the real and\n",
"        imaginary parts of ``y_data / y_psd`` on the frequencies present in\n",
"        both grids (exact matches only).\n",
"\n",
"    Raises:\n",
"        ValueError: if the two frequency grids share no value.\n",
"    \"\"\"\n",
"    x_data, y_data = np.asarray(x_data), np.asarray(y_data)\n",
"    x_psd, y_psd = np.asarray(x_psd), np.asarray(y_psd)\n",
"\n",
"    # intersect1d returns matching index pairs, so data and noise curve stay\n",
"    # aligned even when the two grids are ordered differently (two independent\n",
"    # boolean masks only line up if both grids are sorted the same way).\n",
"    _, i_data, i_psd = np.intersect1d(x_data, x_psd, return_indices=True)\n",
"    if i_data.size == 0:\n",
"        raise ValueError(\"No common frequencies between the data and the PSD\")\n",
"\n",
"    ratio = y_data[i_data] / y_psd[i_psd]\n",
"    ratio_combined = np.concatenate([np.real(ratio), np.imag(ratio)])\n",
"\n",
"    # Only the KS p-value is reported; the Anderson-Darling statistic that used\n",
"    # to be computed here was never used.\n",
"    return kstest(ratio_combined, norm.cdf).pvalue\n",
"\n"
],
"metadata": {
"id": "JYB5sgfbozMc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## GW150914"
],
"metadata": {
"id": "mi_jSuk1B3_l"
}
},
{
"cell_type": "code",
"source": [
"event_name = 'GW150914'\n",
"# Every data stretch below ends 3 s before the catalogued event time\n",
"psd_end_time = event_gps(event_name) - 3\n",
"\n",
"# Complex frequency-domain strain and Welch ASD over 20-896 Hz\n",
"x_data, y_data = get_GW_data(\n",
"    psd_end_time, f_i=20, f_f=896, return_raw_data=True\n",
")\n",
"x_welch, y_welch = get_GW_data_asd_welch(\n",
"    psd_end_time, det=\"H1\", duration=4, f_i=20, f_f=896, method='welch'\n",
")\n",
"\n",
"# ASD stored in the PE data release file for this event\n",
"url = 'https://zenodo.org/records/6513631/files/IGWN-GWTC2p1-v2-GW150914_095045_PEDataRelease_mixed_cosmo.h5?download=1'\n",
"gwtc_psd = get_GWTC_psd(url, event_name, det='H1', f_i=20, f_f=896)\n",
"x_gwtc, y_gwtc = gwtc_psd['H1']"
],
"metadata": {
"id": "mQ-nT8tfB5m7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "08bebc14-23b7-4203-9878-6c60b3928484"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading https://zenodo.org/records/6513631/files/IGWN-GWTC2p1-v2-GW150914_095045_PEDataRelease_mixed_cosmo.h5?download=1 to GWTC_GW150914.h5 ...\n",
"Download complete.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Compute p-value\n",
"pval_welch = get_pval(x_data, y_data, x_welch, y_welch)\n",
"pval_gwtc = get_pval(x_data, y_data, x_gwtc, y_gwtc)\n",
"\n",
"# Plot\n",
"plt.loglog(x_data, np.abs(y_data)**2, label='Data', alpha=0.5, color='lightgray')\n",
"plt.loglog(x_welch, y_welch**2, label=f'Welch (p-val={pval_welch:.3f})')\n",
"plt.loglog(x_gwtc, y_gwtc **2, label=f'GWTC (p-val={pval_gwtc:.3f})')\n",
"plt.xlabel(\"Frequency [Hz]\")\n",
"plt.ylabel(\"PSD [Hz$^{-1}$]\")\n",
"plt.grid(False)\n",
"plt.legend()\n",
"plt.title(\"GW150914\")\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 483
},
"id": "Yhm1w6LsrRKY",
"outputId": "7ebf303b-fd81-47df-b0f4-b3cc2f305e20"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## GW190521"
],
"metadata": {
"id": "hySBJNSmCVXy"
}
},
{
"cell_type": "code",
"source": [
"event_name = 'GW190521'\n",
"# NOTE(review): unlike the GW150914 cell, no offset is subtracted here, so the\n",
"# data stretches end exactly at the catalogued event time - confirm this is intended.\n",
"psd_end_time = event_gps(event_name)\n",
"\n",
"# Complex frequency-domain strain and Welch ASD over 20-896 Hz\n",
"x_data, y_data = get_GW_data(\n",
"    psd_end_time, f_i=20, f_f=896, return_raw_data=True\n",
")\n",
"x_welch, y_welch = get_GW_data_asd_welch(\n",
"    psd_end_time, det=\"H1\", duration=4, f_i=20, f_f=896, method='welch'\n",
")\n",
"\n",
"# ASD stored in the PE data release file for this event, kept up to 225 Hz only\n",
"url = 'https://zenodo.org/records/6513631/files/IGWN-GWTC2p1-v2-GW190521_030229_PEDataRelease_mixed_cosmo.h5?download=1'\n",
"gwtc_psd = get_GWTC_psd(url, event_name, det='H1', f_i=20, f_f=225)\n",
"x_gwtc, y_gwtc = gwtc_psd['H1']\n"
],
"metadata": {
"id": "IH9AjSDyCRfe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Compute p-value\n",
"pval_welch = get_pval(x_data, y_data, x_welch, y_welch)\n",
"pval_gwtc = get_pval(x_data, y_data, x_gwtc, y_gwtc)\n",
"\n",
"# Plot\n",
"plt.loglog(x_data, np.abs(y_data), label='Data', alpha=0.5, color='lightgray')\n",
"plt.loglog(x_welch, y_welch, label=f'Welch (p-val={pval_welch:.3f})')\n",
"plt.loglog(x_gwtc, y_gwtc , label=f'GWTC (p-val={pval_gwtc:.3f})')\n",
"plt.xlabel(\"Frequency [Hz]\")\n",
"plt.ylabel(\"ASD [Hz$^{-1/2}$]\")\n",
"plt.grid(False)\n",
"plt.legend()\n",
"plt.title(\"GW190521\")\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 483
},
"id": "kVFhnadkCX5A",
"outputId": "c0f6c424-5f82-46d5-f110-a4c0b7748ed6"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Grabbing PSDs from OzStar\n",
"\n"
],
"metadata": {
"id": "8CdXHUiOxOVB"
}
},
{
"cell_type": "code",
"source": [
"import glob\n",
"import re\n",
"import os\n",
"import h5py\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from tqdm.auto import tqdm\n",
"from typing import Dict, List, Tuple\n",
"import ast\n",
"from gwpy.timeseries import TimeSeries\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import bilby\n",
"from gwosc.datasets import event_gps\n",
"from gwpy.timeseries import TimeSeries\n",
"import h5py\n",
"from scipy.stats import anderson, kstest, norm\n",
"import os\n",
"import h5py\n",
"import urllib.request\n",
"\n",
"\n",
"def get_event_name(path: str) -> str:\n",
"    \"\"\"Return the event tag (e.g. GW150914_095045) embedded in `path`.\n",
"\n",
"    Raises ValueError when `path` contains no 'GW' + 6 digits + '_' + 6 digits token.\n",
"    \"\"\"\n",
"    found = re.search(r\"(GW\\d{6}_\\d{6})\", path)\n",
"    if found is None:\n",
"        raise ValueError(f\"Could not extract event name from {path}\")\n",
"    return found.group(1)\n",
"\n",
"\n",
"def find_analysis_group(fin: h5py.File) -> str:\n",
"    \"\"\"Pick the analysis group of an open PE release file.\n",
"\n",
"    'C01:IMRPhenomXPHM' wins when it holds 'psds'; otherwise the first group\n",
"    that has 'psds' plus either 'config_file' or any member whose name\n",
"    contains 'config'. Raises KeyError when no group qualifies.\n",
"    \"\"\"\n",
"    default = \"C01:IMRPhenomXPHM\"\n",
"    if default in fin and \"psds\" in fin[default]:\n",
"        return default\n",
"\n",
"    for name in fin.keys():\n",
"        node = fin[name]\n",
"        if not isinstance(node, h5py.Group):\n",
"            continue\n",
"        if \"psds\" not in node:\n",
"            continue\n",
"        if \"config_file\" in node or any(\"config\" in child for child in node.keys()):\n",
"            return name\n",
"\n",
"    raise KeyError(\"No group with both 'psds' and config data found in this file\")\n",
"\n",
"\n",
"def find_psd_group(fin: h5py.File) -> str:\n",
"    \"\"\"Pick the group of an open PE release file to read PSDs from.\n",
"\n",
"    'C01:IMRPhenomXPHM' wins when it holds 'psds'; otherwise the first group\n",
"    holding 'psds'. Raises KeyError when there is none.\n",
"    \"\"\"\n",
"    default = \"C01:IMRPhenomXPHM\"\n",
"    if default in fin and \"psds\" in fin[default]:\n",
"        return default\n",
"\n",
"    with_psds = (\n",
"        name for name in fin.keys()\n",
"        if isinstance(fin[name], h5py.Group) and \"psds\" in fin[name]\n",
"    )\n",
"    first = next(with_psds, None)\n",
"    if first is None:\n",
"        raise KeyError(\"No group with 'psds' found in this file\")\n",
"    return first\n",
"\n",
"\n",
"def get_pe_paths() -> Dict[str, str]:\n",
"    \"\"\"Map event name -> path of its '*PEDataRelease_mixed_cosmo.h5' file.\n",
"\n",
"    Both release directories are searched in order; an event found more than\n",
"    once keeps the path seen last.\n",
"    \"\"\"\n",
"    release_dirs = (\n",
"        \"/datasets/LIGO/public/gwosc.osgstorage.org/gwdata/zenodo/ligo-virgo-kagra/2021/5546662/1\",\n",
"        \"/datasets/LIGO/public/gwosc.osgstorage.org/gwdata/zenodo/ligo-virgo-kagra/2022/5117702/v2\",\n",
"    )\n",
"    paths = {}\n",
"    for release_dir in release_dirs:\n",
"        for fname in glob.glob(f\"{release_dir}/*PEDataRelease_mixed_cosmo.h5\"):\n",
"            paths[get_event_name(fname)] = fname\n",
"    return paths\n",
"\n",
"\n",
"def extract_psds(outdir: str, fmin: float = 20.0, fmax: float = 2048.0) -> str:\n",
"    \"\"\"Collect the PSDs of every PE release file into ``<outdir>/GWTC3_psds.h5``.\n",
"\n",
"    One group is written per event (attribute 'psd_group' records the source\n",
"    group) with one sub-group per detector holding 'freqs' and 'psd' limited\n",
"    to ``fmin <= f <= fmax``. V1 and K1 are left out, and files without a\n",
"    PSD group are reported and skipped. Returns the path of the written file.\n",
"    \"\"\"\n",
"    skipped_ifos = ('V1', \"K1\")\n",
"    pe_files = get_pe_paths()\n",
"    os.makedirs(outdir, exist_ok=True)\n",
"    outfn = os.path.join(outdir, \"GWTC3_psds.h5\")\n",
"\n",
"    with h5py.File(outfn, \"w\") as fout:\n",
"        for event, fpath in tqdm(pe_files.items(), desc=\"Extracting PSDs\"):\n",
"            with h5py.File(fpath, \"r\") as fin:\n",
"                try:\n",
"                    source_group = find_psd_group(fin)\n",
"                except KeyError:\n",
"                    print(f\"[WARN] No psds group found in {fpath}\")\n",
"                    continue\n",
"\n",
"                psd_tables = fin[source_group][\"psds\"]\n",
"                event_group = fout.create_group(event)\n",
"                event_group.attrs[\"psd_group\"] = source_group\n",
"\n",
"                for ifo_name in psd_tables:\n",
"                    if ifo_name in skipped_ifos:\n",
"                        continue\n",
"\n",
"                    freqs = psd_tables[ifo_name][:, 0]\n",
"                    vals = psd_tables[ifo_name][:, 1]\n",
"                    keep = (freqs >= fmin) & (freqs <= fmax)\n",
"\n",
"                    ifo_group = event_group.create_group(ifo_name)\n",
"                    ifo_group.create_dataset(\"freqs\", data=freqs[keep])\n",
"                    ifo_group.create_dataset(\"psd\", data=vals[keep])\n",
"\n",
"    return outfn\n",
"\n",
"\n",
"def get_specific_configs(file_path: str, h5_config_group_path: str, keys_to_retrieve: List[str]) -> Dict:\n",
"    \"\"\"Read selected entries of a config group stored in an HDF5 file.\n",
"\n",
"    Each requested dataset is reduced to a scalar, decoded when it is bytes\n",
"    and parsed with ``ast.literal_eval``; values that do not parse are kept\n",
"    unparsed. Problems are printed rather than raised, so the returned dict\n",
"    holds only the keys that could be read (possibly none).\n",
"    \"\"\"\n",
"    configs_dict = {}\n",
"\n",
"    try:\n",
"        if not os.path.exists(file_path):\n",
"            print(f\"Error: The file '{file_path}' was not found.\")\n",
"            return configs_dict\n",
"\n",
"        with h5py.File(file_path, 'r') as f:\n",
"            if h5_config_group_path not in f:\n",
"                print(f\"Error: The HDF5 path '{h5_config_group_path}' was not found.\")\n",
"                return configs_dict\n",
"\n",
"            config_group = f[h5_config_group_path]\n",
"\n",
"            for key in keys_to_retrieve:\n",
"                try:\n",
"                    obj = config_group[key]\n",
"\n",
"                    if not isinstance(obj, h5py.Dataset):\n",
"                        print(f\"Warning: '{key}' is not a dataset and was skipped.\")\n",
"                        continue\n",
"\n",
"                    content = obj[()]\n",
"\n",
"                    # Arrays are reduced to their single element\n",
"                    if isinstance(content, np.ndarray):\n",
"                        if content.size == 0:\n",
"                            print(f\"Warning: Empty array for key '{key}'\")\n",
"                            continue\n",
"                        content = content.item()\n",
"\n",
"                    if isinstance(content, bytes):\n",
"                        content = content.decode('utf-8')\n",
"\n",
"                    # Prefer the parsed Python literal, fall back to the value as read\n",
"                    try:\n",
"                        parsed = ast.literal_eval(content)\n",
"                    except (ValueError, SyntaxError):\n",
"                        print(f\"Info: Storing '{key}' as string (couldn't parse as literal)\")\n",
"                        parsed = content\n",
"\n",
"                    configs_dict[key] = parsed\n",
"\n",
"                except KeyError:\n",
"                    print(f\"Warning: Key '{key}' not found in the HDF5 group.\")\n",
"\n",
"                except Exception as e:\n",
"                    print(f\"Warning: Error processing '{key}': {e}\")\n",
"\n",
"    except Exception as e:\n",
"        print(f\"An unexpected error occurred: {e}\")\n",
"\n",
"    return configs_dict\n",
"\n",
"\n",
"def get_gw_event_configs(fpath: str, analysis_group: str = None) -> Dict:\n",
"    \"\"\"Read an event's run configuration and derive its data-segment times.\n",
"\n",
"    Args:\n",
"        fpath: Path to a PE data release HDF5 file.\n",
"        analysis_group: Analysis group to read; chosen with\n",
"            ``find_analysis_group`` when None.\n",
"\n",
"    Returns:\n",
"        The selected config entries plus 'analysis_group' and, when both\n",
"        'trigger-time' and 'duration' are available, the derived\n",
"        'analysis_*', 'psd_*' and 'postevent_*' times.\n",
"\n",
"    Raises:\n",
"        KeyError: if no config group exists under the analysis group.\n",
"    \"\"\"\n",
"    with h5py.File(fpath, 'r') as fin:\n",
"        if analysis_group is None:\n",
"            analysis_group = find_analysis_group(fin)\n",
"\n",
"        # Try different possible config paths\n",
"        possible_config_paths = [\n",
"            f'{analysis_group}/config_file/config',\n",
"            f'{analysis_group}/config',\n",
"            f'{analysis_group}/configuration'\n",
"        ]\n",
"        h5_config_group = next((path for path in possible_config_paths if path in fin), None)\n",
"\n",
"    if h5_config_group is None:\n",
"        raise KeyError(f\"No config group found in {analysis_group}\")\n",
"\n",
"    desired_keys = [\n",
"        'channel-dict', 'deltaT', 'duration', 'maximum-frequency',\n",
"        'minimum-frequency', 'post-trigger-duration', 'psd-fractional-overlap',\n",
"        'psd-length', 'psd-maximum-duration', 'psd-method', 'trigger-time',\n",
"        'tukey-roll-off'\n",
"    ]\n",
"\n",
"    # The file is closed again by now; get_specific_configs opens it itself\n",
"    configs = get_specific_configs(fpath, h5_config_group, desired_keys)\n",
"    configs['analysis_group'] = analysis_group  # Store which group was used\n",
"\n",
"    def extract_scalar(value):\n",
"        \"\"\"Unwrap numpy scalars/arrays and one-element lists or tuples.\"\"\"\n",
"        if hasattr(value, 'item'):\n",
"            return value.item()\n",
"        # Only positional sequences are unwrapped: indexing a one-entry dict\n",
"        # with [0] would raise KeyError.\n",
"        if isinstance(value, (list, tuple)) and len(value) == 1:\n",
"            return value[0]\n",
"        return value\n",
"\n",
"    def config_float(key, default):\n",
"        \"\"\"Float value of configs[key], or `default` when missing or None.\"\"\"\n",
"        raw = extract_scalar(configs.get(key))\n",
"        if raw is None or (isinstance(raw, str) and raw.strip().lower() in ('', 'none')):\n",
"            return default\n",
"        return float(raw)\n",
"\n",
"    # Calculate timing parameters if we have the required keys\n",
"    if all(key in configs for key in ['trigger-time', 'duration']):\n",
"        trigger_time = float(extract_scalar(configs['trigger-time']))\n",
"        duration = float(extract_scalar(configs['duration']))\n",
"\n",
"        # In bilby_pipe the analysis segment ends 'post-trigger-duration'\n",
"        # seconds after the trigger (2 s unless configured otherwise);\n",
"        # 'deltaT' is the width of the coalescence-time prior and does not\n",
"        # set the segment boundaries.\n",
"        post_trigger_duration = config_float('post-trigger-duration', 2.0)\n",
"\n",
"        end_time = trigger_time + post_trigger_duration\n",
"        start_time = end_time - duration\n",
"        psd_end_time = start_time\n",
"\n",
"        # PSD stretch: 'psd-length' analysis durations (32 unless configured),\n",
"        # capped by 'psd-maximum-duration' when that is set\n",
"        psd_duration = config_float('psd-length', 32.0) * duration\n",
"        psd_max_duration = config_float('psd-maximum-duration', None)\n",
"        if psd_max_duration is not None:\n",
"            psd_duration = min(psd_duration, psd_max_duration)\n",
"        psd_start_time = psd_end_time - psd_duration\n",
"\n",
"        # Add calculated times to configs\n",
"        configs['analysis_start_time'] = start_time\n",
"        configs['analysis_end_time'] = end_time\n",
"        configs['psd_start_time'] = psd_start_time\n",
"        configs['psd_end_time'] = psd_end_time\n",
"        configs['psd_duration'] = psd_duration\n",
"        configs['postevent_start_time'] = end_time\n",
"        configs['postevent_end_time'] = end_time + duration\n",
"\n",
"    return configs\n",
"\n",
"\n",
"def _get_data_files_and_gps_times(det: str = \"L1\") -> Dict[int, str]:\n",
"    \"\"\"Map GPS start time -> strain file path for `det`, ordered by GPS time.\n",
"\n",
"    The start time is the first number of the 'R1-<start>-<n>.hdf5' part of\n",
"    each file name; files not matching that pattern are ignored.\n",
"    Raises FileNotFoundError when the search finds no file at all.\n",
"    \"\"\"\n",
"    search_str = f\"/datasets/LIGO/public/gwosc.osgstorage.org/gwdata/O3b/strain.4k/hdf.v1/{det}/*/*.hdf5\"\n",
"    files = glob.glob(search_str)\n",
"    if len(files) == 0:\n",
"        raise FileNotFoundError(f\"No HDF5 files found at {search_str}\")\n",
"\n",
"    gps_pattern = re.compile(r\"R1-(\\d+)-\\d+\\.hdf5\")\n",
"    matched = ((gps_pattern.search(fname), fname) for fname in files)\n",
"    path_dict = {int(m.group(1)): fname for m, fname in matched if m}\n",
"\n",
"    return dict(sorted(path_dict.items()))\n",
"\n",
"\n",
"def get_data_dicts() -> Dict[str, Dict[int, str]]:\n",
"    \"\"\"Return the GPS-start -> file mapping for L1 and H1, keyed by detector.\"\"\"\n",
"    return {det: _get_data_files_and_gps_times(det) for det in (\"L1\", \"H1\")}\n",
"\n",
"\n",
"def get_fnames_for_range(gps_start: float, gps_end: float, det: str = \"L1\") -> List[str]:\n",
"    \"\"\"List the strain files of `det` that overlap ``[gps_start, gps_end]``.\n",
"\n",
"    Each file is taken to cover the time from its own GPS start up to the GPS\n",
"    start of the next file (the last file is open-ended).\n",
"\n",
"    Args:\n",
"        gps_start: Start of the requested span (GPS seconds).\n",
"        gps_end: End of the requested span (GPS seconds).\n",
"        det: Detector name.\n",
"\n",
"    Returns:\n",
"        File paths ordered by GPS start time; empty if nothing overlaps.\n",
"    \"\"\"\n",
"    import math\n",
"\n",
"    # Round outwards: truncating a fractional end time could drop the file\n",
"    # that starts on the preceding whole second and holds the tail of the span.\n",
"    gps_start = math.floor(gps_start)\n",
"    gps_end = math.ceil(gps_end)\n",
"\n",
"    gps_files = _get_data_files_and_gps_times(det)\n",
"    start_times = sorted(gps_files.keys())\n",
"    next_starts = start_times[1:] + [float('inf')]\n",
"\n",
"    return [\n",
"        gps_files[t0]\n",
"        for t0, t1 in zip(start_times, next_starts)\n",
"        if gps_end > t0 and gps_start < t1\n",
"    ]\n",
"\n",
"\n",
"def load_strain_segment(gps_start: float, gps_end: float, detector: str = \"L1\") -> Tuple[np.ndarray, np.ndarray]:\n",
"    \"\"\"Read the strain of `detector` between `gps_start` and `gps_end`.\n",
"\n",
"    The files are looked up with ``get_fnames_for_range`` and read with gwpy.\n",
"    Returns ``(times, strain)`` as numpy arrays. Raises ValueError when no\n",
"    file covers the span or when reading fails.\n",
"    \"\"\"\n",
"    files = get_fnames_for_range(gps_start, gps_end, detector)\n",
"    if len(files) == 0:\n",
"        raise ValueError(f\"No files found for {detector} in time range {gps_start}-{gps_end}\")\n",
"\n",
"    try:\n",
"        series = TimeSeries.read(files, format='hdf5.gwosc', start=gps_start, end=gps_end)\n",
"        sample_times = series.times.value\n",
"        samples = series.value\n",
"        return sample_times, samples\n",
"\n",
"    except Exception as e:\n",
"        raise ValueError(f\"Could not read strain data for {detector}: {e}\")\n",
"\n",
"\n",
"def read_strain_data(file_paths: List[str], gps_start: float, gps_end: float,\n",
"                     detector: str) -> Tuple[np.ndarray, np.ndarray]:\n",
"    \"\"\"Read the strain between `gps_start` and `gps_end` from the given files.\n",
"\n",
"    Returns ``(times, strain)`` as numpy arrays. Raises ValueError when\n",
"    `file_paths` is empty or when reading fails; `detector` is only used in\n",
"    the error messages.\n",
"    \"\"\"\n",
"    if not file_paths:\n",
"        raise ValueError(f\"No files provided for {detector}\")\n",
"\n",
"    try:\n",
"        series = TimeSeries.read(file_paths, format='hdf5.gwosc', start=gps_start, end=gps_end)\n",
"        sample_times = series.times.value\n",
"        samples = series.value\n",
"        return sample_times, samples\n",
"\n",
"    except Exception as e:\n",
"        raise ValueError(f\"Could not read strain data for {detector} from {len(file_paths)} files: {e}\")\n",
"\n",
"\n",
"def get_strain_data_for_event(event_configs: Dict, detector: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:\n",
"    \"\"\"Load the 'analysis', 'postevent' and 'psd' strain segments of one detector.\n",
"\n",
"    A segment is attempted only when both of its start/end times are present\n",
"    in `event_configs`. Load failures are printed as warnings, so the result\n",
"    maps segment name -> ``(times, strain)`` for the segments that worked.\n",
"    \"\"\"\n",
"    # (result key, start-time key, end-time key, name used in the warning)\n",
"    segments = (\n",
"        ('analysis', 'analysis_start_time', 'analysis_end_time', 'analysis'),\n",
"        ('postevent', 'postevent_start_time', 'postevent_end_time', 'postevent'),\n",
"        ('psd', 'psd_start_time', 'psd_end_time', 'PSD'),\n",
"    )\n",
"\n",
"    result = {}\n",
"    for name, start_key, end_key, label in segments:\n",
"        if start_key not in event_configs or end_key not in event_configs:\n",
"            continue\n",
"        try:\n",
"            times, strain = load_strain_segment(\n",
"                event_configs[start_key], event_configs[end_key], detector\n",
"            )\n",
"            result[name] = (times, strain)\n",
"        except Exception as e:\n",
"            print(f\"Warning: Could not get {label} data for {detector}: {e}\")\n",
"\n",
"    return result\n",
"\n",
"\n",
"def get_welch_psd(strain_data: np.ndarray, times: np.ndarray,\n",
"                  analysis_duration: float, roll_off: float = 0.4,\n",
"                  overlap: float = 0.5) -> Tuple[np.ndarray, np.ndarray]:\n",
"    \"\"\"\n",
"    Calculate a median-averaged Welch PSD estimate from strain data.\n",
"\n",
"    Follows bilby_pipe:\n",
"    https://lscsoft.docs.ligo.org/bilby_pipe/master/_modules/bilby_pipe/data_generation.html#DataGenerationInput.__generate_psd\n",
"\n",
"    Args:\n",
"        strain_data: Strain samples.\n",
"        times: Sample times matching `strain_data`.\n",
"        analysis_duration: Length (s) of each FFT segment.\n",
"        roll_off: Tukey roll-off (s) applied at each end of a segment.\n",
"        overlap: Fraction of `analysis_duration` shared by neighbouring segments.\n",
"\n",
"    Returns:\n",
"        ``(frequencies, psd)`` as numpy arrays.\n",
"    \"\"\"\n",
"    # Create TimeSeries object\n",
"    strain_ts = TimeSeries(strain_data, times=times)\n",
"\n",
"    # The Tukey window takes the *fraction* of the segment that is tapered,\n",
"    # not the roll-off time itself; convert as get_GW_data_asd_welch does.\n",
"    tukey_alpha = 2 * roll_off / analysis_duration\n",
"\n",
"    # Calculate Welch PSD\n",
"    psd = strain_ts.psd(\n",
"        fftlength=analysis_duration,\n",
"        overlap=analysis_duration * overlap,\n",
"        window=('tukey', tukey_alpha),\n",
"        method='median'\n",
"    )\n",
"\n",
"    return psd.frequencies.value, psd.value\n",
"\n",
"\n",
"def get_fd_data(strain_data: np.ndarray, times: np.ndarray, det: str, roll_off: float, fmin: float, fmax: float):\n",
"    \"\"\"Frequency-domain strain of a time-domain segment, via a bilby interferometer.\n",
"\n",
"    Returns ``(frequencies, strain)`` limited to ``fmin <= f <= fmax``, with the\n",
"    complex strain divided by the square root of the interferometer's window factor.\n",
"    \"\"\"\n",
"    segment = TimeSeries(strain_data, times=times)\n",
"\n",
"    ifo = bilby.gw.detector.get_empty_interferometer(det)\n",
"    ifo.strain_data.roll_off = roll_off\n",
"    ifo.maximum_frequency = fmax\n",
"    ifo.minimum_frequency = fmin\n",
"    ifo.strain_data.set_from_gwpy_timeseries(segment)\n",
"\n",
"    freqs = ifo.strain_data.frequency_array\n",
"    fd_strain = ifo.strain_data.frequency_domain_strain\n",
"    # Read the window factor only after the frequency-domain strain, in the\n",
"    # same order as before (presumably bilby sets it while windowing - TODO confirm)\n",
"    scale = np.sqrt(ifo.strain_data.window_factor)\n",
"\n",
"    band = (freqs >= fmin) & (freqs <= fmax)\n",
"    return freqs[band], fd_strain[band] / scale\n",
"\n",
"\n",
"def get_pval(x_data, y_data, x_psd, y_psd):\n",
"    \"\"\"KS-test p-value for the noise-curve-normalised data being standard normal.\n",
"\n",
"    Args:\n",
"        x_data: Frequencies of the data.\n",
"        y_data: Complex frequency-domain data at `x_data`.\n",
"        x_psd: Frequencies of the noise curve.\n",
"        y_psd: Noise-curve values at `x_psd` that the data are divided by.\n",
"\n",
"    Returns:\n",
"        p-value of ``scipy.stats.kstest`` against ``norm.cdf`` for the real and\n",
"        imaginary parts of ``y_data / y_psd`` on the frequencies present in\n",
"        both grids (exact matches only).\n",
"\n",
"    Raises:\n",
"        ValueError: if the two frequency grids share no value.\n",
"    \"\"\"\n",
"    x_data, y_data = np.asarray(x_data), np.asarray(y_data)\n",
"    x_psd, y_psd = np.asarray(x_psd), np.asarray(y_psd)\n",
"\n",
"    # intersect1d returns matching index pairs, so data and noise curve stay\n",
"    # aligned even when the two grids are ordered differently (two independent\n",
"    # boolean masks only line up if both grids are sorted the same way).\n",
"    _, i_data, i_psd = np.intersect1d(x_data, x_psd, return_indices=True)\n",
"    if i_data.size == 0:\n",
"        raise ValueError(\"No common frequencies between the data and the PSD\")\n",
"\n",
"    ratio = y_data[i_data] / y_psd[i_psd]\n",
"    ratio_combined = np.concatenate([np.real(ratio), np.imag(ratio)])\n",
"\n",
"    # Only the KS p-value is reported; the Anderson-Darling statistic that used\n",
"    # to be computed here was never used.\n",
"    return kstest(ratio_combined, norm.cdf).pvalue\n",
"\n",
"\n",
"class GWEventData:\n",
" \"\"\"Class to store and manage gravitational wave event data.\"\"\"\n",
"\n",
" def __init__(self, event_name: str):\n",
"     \"\"\"Create an empty container for `event_name`; the loaders fill it in.\"\"\"\n",
"     self.event_name = event_name\n",
"     # Run configuration read from the PE release file\n",
"     self.configs = {}\n",
"     # PSDs from the PE file: {detector: {'freqs': array, 'psd': array}}\n",
"     self.psds = {}\n",
"     # Strain segments: {detector: {segment name: (times, strain)}}\n",
"     self.strain_data = {}\n",
"     # Welch PSDs computed from strain: {detector: {'freqs': array, 'psd': array}}\n",
"     self.welch_psds = {}\n",
"     # Post-event frequency-domain data: {detector: {'freqs': array, 'datafd': array}}\n",
"     self.postevent_fd = {}\n",
"     self.analysis_group = None\n",
"     # Minimum / maximum analysis frequencies (per detector)\n",
"     self.fmin = {}\n",
"     self.fmax = {}\n",
"\n",
" @classmethod\n",
" def from_ozstar(cls, event_name: str):\n",
"     \"\"\"Load one event's configs, PSDs and strain from the GWOSC files on OzStar.\n",
"\n",
"     The run configuration and the H1/L1 PSDs come from the event's PE release\n",
"     file. For each detector with a PSD, the strain segments are loaded, a\n",
"     Welch PSD is computed from the 'psd' segment and the frequency-domain\n",
"     strain from the 'postevent' segment. Strain problems are printed as\n",
"     warnings and leave the corresponding entries unset.\n",
"\n",
"     Raises:\n",
"         ValueError: if `event_name` has no PE release file.\n",
"         KeyError: if a required config key is missing.\n",
"     \"\"\"\n",
"     instance = cls(event_name)\n",
"\n",
"     # Get PE file path\n",
"     pe_paths = get_pe_paths()\n",
"     if event_name not in pe_paths:\n",
"         raise ValueError(f\"Event {event_name} not found in PE data paths\")\n",
"     pe_file_path = pe_paths[event_name]\n",
"\n",
"     # Load configs\n",
"     instance.configs = get_gw_event_configs(pe_file_path)\n",
"     instance.analysis_group = instance.configs.get('analysis_group')\n",
"\n",
"     def get_config_value(key):\n",
"         \"\"\"Return a config entry: dicts unchanged, anything else as float.\"\"\"\n",
"         value = instance.configs.get(key)\n",
"         if value is None:\n",
"             raise KeyError(f\"Config key '{key}' not found\")\n",
"         if isinstance(value, dict):\n",
"             return value\n",
"         return float(instance._extract_scalar(value))\n",
"\n",
"     def band_limit(limits, ifo, fallback):\n",
"         \"\"\"Resolve a frequency limit given per detector (dict) or as one number.\"\"\"\n",
"         # get_config_value returns a float when the config holds a single\n",
"         # number, so .get() must not be assumed to exist.\n",
"         if isinstance(limits, dict):\n",
"             limits = limits.get(ifo, limits.get('H1', fallback))\n",
"         return float(instance._extract_scalar(limits))\n",
"\n",
"     fmin = get_config_value('minimum-frequency')\n",
"     fmax = get_config_value('maximum-frequency')\n",
"     analysis_duration = get_config_value('duration')\n",
"     overlap = get_config_value('psd-fractional-overlap')\n",
"     roll_off = get_config_value('tukey-roll-off')\n",
"\n",
"     # Store the frequency limits in the instance for later use\n",
"     instance.fmin = fmin\n",
"     instance.fmax = fmax\n",
"\n",
"     # Load PSDs from PE file\n",
"     with h5py.File(pe_file_path, \"r\") as fin:\n",
"         analysis_group = instance.analysis_group or find_analysis_group(fin)\n",
"         psd_data = fin[analysis_group][\"psds\"]\n",
"\n",
"         for ifo_name in psd_data:\n",
"             if ifo_name not in ['H1', 'L1']:  # Only load H1 and L1\n",
"                 continue\n",
"             freqs = psd_data[ifo_name][:, 0]\n",
"             vals = psd_data[ifo_name][:, 1]\n",
"\n",
"             detector_fmin = band_limit(fmin, ifo_name, 20.0)\n",
"             detector_fmax = band_limit(fmax, ifo_name, 2048.0)\n",
"\n",
"             mask = (freqs >= detector_fmin) & (freqs <= detector_fmax)\n",
"             instance.psds[ifo_name] = {\n",
"                 'freqs': freqs[mask],\n",
"                 'psd': vals[mask]\n",
"             }\n",
"\n",
"     # Load strain data\n",
"     print(f\"Loading strain data for {event_name}...\")\n",
"     for detector in ['H1', 'L1']:\n",
"         if detector not in instance.psds:  # Only load for detectors we have PSDs for\n",
"             continue\n",
"         try:\n",
"             strain_data = get_strain_data_for_event(instance.configs, detector)\n",
"             if not strain_data:\n",
"                 continue\n",
"             instance.strain_data[detector] = strain_data\n",
"             print(f\" {detector}: Loaded {len(strain_data)} data segments\")\n",
"\n",
"             detector_fmin = band_limit(fmin, detector, 20.0)\n",
"             detector_fmax = band_limit(fmax, detector, 2048.0)\n",
"\n",
"             # Welch PSD from the 'psd' segment, cut to the analysis band\n",
"             if 'psd' in strain_data:\n",
"                 psd_times, psd_strain = strain_data['psd']\n",
"                 welch_freqs, welch_psd = get_welch_psd(\n",
"                     psd_strain, psd_times, analysis_duration, roll_off, overlap\n",
"                 )\n",
"                 mask = (welch_freqs >= detector_fmin) & (welch_freqs <= detector_fmax)\n",
"                 instance.welch_psds[detector] = {\n",
"                     'freqs': welch_freqs[mask],\n",
"                     'psd': welch_psd[mask]\n",
"                 }\n",
"\n",
"             # Frequency-domain strain of the 'postevent' segment\n",
"             if 'postevent' in strain_data:\n",
"                 postevent_times, postevent_strain = strain_data['postevent']\n",
"                 postevent_freqs, postevent_fd = get_fd_data(\n",
"                     postevent_strain, postevent_times, detector, roll_off, detector_fmin, detector_fmax\n",
"                 )\n",
"                 instance.postevent_fd[detector] = {\n",
"                     'freqs': postevent_freqs,\n",
"                     'datafd': postevent_fd\n",
"                 }\n",
"\n",
"         except Exception as e:\n",
"             print(f\" Warning: Could not load strain data for {detector}: {e}\")\n",
"\n",
"     return instance\n",
"\n",
" def _extract_scalar(self, value):\n",
"     \"\"\"Reduce a config value to a plain Python scalar where possible.\n",
"\n",
"     Bytes are decoded and strings parsed with ``ast.literal_eval`` (text that\n",
"     does not parse is returned as a string). numpy scalars/arrays are\n",
"     unwrapped with ``.item()`` and one-element lists or tuples with ``[0]``;\n",
"     anything else is returned unchanged.\n",
"     \"\"\"\n",
"     if isinstance(value, bytes):\n",
"         try:\n",
"             value_str = value.decode('utf-8')\n",
"         except UnicodeDecodeError:\n",
"             # Decoding strictly a second time would just raise again, so\n",
"             # fall back to a lossy decode to still hand back a string.\n",
"             return value.decode('utf-8', errors='replace')\n",
"         try:\n",
"             value = ast.literal_eval(value_str)\n",
"         except (ValueError, SyntaxError):\n",
"             return value_str\n",
"\n",
"     # Handle string representations of Python objects\n",
"     if isinstance(value, str):\n",
"         try:\n",
"             value = ast.literal_eval(value)\n",
"         except (ValueError, SyntaxError):\n",
"             return value\n",
"\n",
"     if hasattr(value, 'item'):  # numpy scalar or array\n",
"         return value.item()\n",
"     # Only positional sequences are unwrapped: indexing a one-entry dict\n",
"     # with [0] would raise KeyError.\n",
"     if isinstance(value, (list, tuple)) and len(value) == 1:\n",
"         return value[0]\n",
"     return value\n",
"\n",
"    def to_hdf5(self, output_dir: str = \".\", filename: str = None):\n",
"        \"\"\"Write this event's configs, PSDs, FD data and strain to one HDF5 file.\n",
"\n",
"        File layout: root attributes ``event_name`` and ``analysis_group``; groups\n",
"        ``configs`` and ``psds``; and, only when the matching attribute is\n",
"        non-empty, ``welch_psds``, ``postevent_fd`` and ``strain_data``.\n",
"\n",
"        Args:\n",
"            output_dir: Directory to write into; created if it does not exist.\n",
"            filename: File name inside ``output_dir``. Defaults to\n",
"                ``<event_name>_data.h5``.\n",
"\n",
"        Returns:\n",
"            Path of the written file.\n",
"        \"\"\"\n",
"        if filename is None:\n",
"            filename = f\"{self.event_name}_data.h5\"\n",
"\n",
"        os.makedirs(output_dir, exist_ok=True)\n",
"        filepath = os.path.join(output_dir, filename)\n",
"\n",
"        with h5py.File(filepath, \"w\") as f:\n",
"            # Save metadata\n",
"            f.attrs['event_name'] = self.event_name\n",
"            f.attrs['analysis_group'] = self.analysis_group or 'unknown'\n",
"\n",
"            # Save configs; a value that cannot be stored is reported and skipped\n",
"            config_group = f.create_group(\"configs\")\n",
"            for key, value in self.configs.items():\n",
"                try:\n",
"                    if isinstance(value, bytes):\n",
"                        # Store the text itself: str(bytes) would store the b'...' repr\n",
"                        data = value.decode('utf-8', errors='replace')\n",
"                    elif isinstance(value, (int, float, np.number, list, tuple, np.ndarray)):\n",
"                        data = value\n",
"                    else:\n",
"                        # str, dict and any other object are stored in string form\n",
"                        data = str(value)\n",
"                    config_group.create_dataset(key, data=data)\n",
"                except Exception as e:\n",
"                    print(f\"Warning: Could not save config '{key}': {e}\")\n",
"\n",
"            # Save PSDs (from PE file)\n",
"            psd_group = f.create_group(\"psds\")\n",
"            for detector, psd_data in self.psds.items():\n",
"                det_group = psd_group.create_group(detector)\n",
"                det_group.create_dataset(\"freqs\", data=psd_data['freqs'])\n",
"                det_group.create_dataset(\"psd\", data=psd_data['psd'])\n",
"\n",
"            # Save Welch PSDs (computed from strain data)\n",
"            if self.welch_psds:\n",
"                welch_group = f.create_group(\"welch_psds\")\n",
"                for detector, psd_data in self.welch_psds.items():\n",
"                    det_group = welch_group.create_group(detector)\n",
"                    det_group.create_dataset(\"freqs\", data=psd_data['freqs'])\n",
"                    det_group.create_dataset(\"psd\", data=psd_data['psd'])\n",
"\n",
"            # Save postevent frequency domain data\n",
"            if self.postevent_fd:\n",
"                fd_group = f.create_group(\"postevent_fd\")\n",
"                for detector, fd_data in self.postevent_fd.items():\n",
"                    det_group = fd_group.create_group(detector)\n",
"                    det_group.create_dataset(\"freqs\", data=fd_data['freqs'])\n",
"                    det_group.create_dataset(\"datafd\", data=fd_data['datafd'])\n",
"\n",
"            # Save strain data\n",
"            if self.strain_data:\n",
"                strain_group = f.create_group(\"strain_data\")\n",
"                for detector, data_dict in self.strain_data.items():\n",
"                    det_group = strain_group.create_group(detector)\n",
"                    for data_type, (times, strain) in data_dict.items():\n",
"                        type_group = det_group.create_group(data_type)\n",
"                        type_group.create_dataset(\"times\", data=times)\n",
"                        type_group.create_dataset(\"strain\", data=strain)\n",
"\n",
"                        # Timing metadata needs at least one sample\n",
"                        if len(times) == 0:\n",
"                            continue\n",
"                        duration = times[-1] - times[0]\n",
"                        type_group.attrs['duration'] = duration\n",
"                        # N samples span N-1 intervals; the rate is undefined\n",
"                        # when first and last timestamps coincide\n",
"                        if duration > 0:\n",
"                            type_group.attrs['sample_rate'] = (len(times) - 1) / duration\n",
"                        type_group.attrs['start_time'] = times[0]\n",
"                        type_group.attrs['end_time'] = times[-1]\n",
"\n",
"        return filepath\n",
"\n",
"    @classmethod\n",
"    def from_hdf5(cls, filepath: str):\n",
"        \"\"\"Rebuild an event object from an HDF5 file with the ``to_hdf5`` layout.\n",
"\n",
"        Reads the root attributes and the ``configs`` and ``psds`` groups, plus the\n",
"        ``welch_psds``, ``postevent_fd`` and ``strain_data`` groups when the file\n",
"        contains them. Byte strings in attributes and configs are decoded as UTF-8.\n",
"        \"\"\"\n",
"        def as_text(raw):\n",
"            # UTF-8 text for bytes, anything else untouched\n",
"            return raw.decode('utf-8') if isinstance(raw, bytes) else raw\n",
"\n",
"        def load_spectra(target, group, value_key):\n",
"            # One {'freqs': ..., value_key: ...} entry per detector sub-group\n",
"            for name in group.keys():\n",
"                node = group[name]\n",
"                target[name] = {\n",
"                    'freqs': node[\"freqs\"][:],\n",
"                    value_key: node[value_key][:]\n",
"                }\n",
"\n",
"        with h5py.File(filepath, \"r\") as f:\n",
"            instance = cls(as_text(f.attrs['event_name']))\n",
"            instance.analysis_group = as_text(f.attrs.get('analysis_group', 'unknown'))\n",
"\n",
"            # Configs: one dataset per key, read in full\n",
"            stored_configs = f[\"configs\"]\n",
"            for key in stored_configs.keys():\n",
"                instance.configs[key] = as_text(stored_configs[key][()])\n",
"\n",
"            # PE-file PSDs are mandatory; the remaining groups are optional\n",
"            load_spectra(instance.psds, f[\"psds\"], 'psd')\n",
"            if \"welch_psds\" in f:\n",
"                load_spectra(instance.welch_psds, f[\"welch_psds\"], 'psd')\n",
"            if \"postevent_fd\" in f:\n",
"                load_spectra(instance.postevent_fd, f[\"postevent_fd\"], 'datafd')\n",
"\n",
"            # Strain: detector -> data type -> (times, strain)\n",
"            if \"strain_data\" in f:\n",
"                stored_strain = f[\"strain_data\"]\n",
"                for detector in stored_strain.keys():\n",
"                    per_detector = stored_strain[detector]\n",
"                    instance.strain_data[detector] = {}\n",
"                    for data_type in per_detector.keys():\n",
"                        segment = per_detector[data_type]\n",
"                        series = (segment[\"times\"][:], segment[\"strain\"][:])\n",
"                        instance.strain_data[detector][data_type] = series\n",
"\n",
"        return instance\n",
"\n",
"    def plot_psds(self, output_dir: str = \".\"):\n",
"        \"\"\"Plot each detector's PE-file PSD with its Welch PSD and post-event data.\n",
"\n",
"        Draws one panel per detector in ``self.psds`` and saves the figure as\n",
"        ``<output_dir>/<event_name>_psd.png``. A detector without an entry in\n",
"        ``self.welch_psds`` or ``self.postevent_fd`` is skipped with a warning,\n",
"        leaving its panel empty. Legend entries show the values returned by\n",
"        ``get_pval`` for the PE-file and Welch spectra.\n",
"\n",
"        Args:\n",
"            output_dir: Directory for the PNG; created if it does not exist.\n",
"        \"\"\"\n",
"        if not self.psds:\n",
"            print(\"No PSD data to plot\")\n",
"            return\n",
"\n",
"        # Create subplots for each detector\n",
"        n_detectors = len(self.psds)\n",
"        fig, axes = plt.subplots(n_detectors, 1, figsize=(10, 4*n_detectors),\n",
"                                 sharex=True, squeeze=False)\n",
"        axes = axes.flatten()\n",
"\n",
"        colors = {'H1': 'red', 'L1': 'blue'}\n",
"        fallback_color = 'green'  # used for any detector without an entry above\n",
"\n",
"        # try/finally: the figure is closed even when plotting or saving fails,\n",
"        # so a batch run that swallows the error does not accumulate figures\n",
"        try:\n",
"            for i, detector in enumerate(sorted(self.psds.keys())):\n",
"                ax = axes[i]\n",
"                psd_data = self.psds[detector]\n",
"\n",
"                # Only proceed if we have the required data\n",
"                if detector not in self.welch_psds or detector not in self.postevent_fd:\n",
"                    print(f\"Warning: Missing data for {detector}, skipping plot\")\n",
"                    continue\n",
"\n",
"                welch_data = self.welch_psds[detector]\n",
"                fd_data = self.postevent_fd[detector]\n",
"                color = colors.get(detector, fallback_color)\n",
"\n",
"                # Calculate p-values; get_pval is passed sqrt(PSD) for each spectrum\n",
"                pval_gwtc = get_pval(fd_data['freqs'], fd_data['datafd'], psd_data['freqs'], np.sqrt(psd_data['psd']))\n",
"                pval_welch = get_pval(fd_data['freqs'], fd_data['datafd'], welch_data['freqs'], np.sqrt(welch_data['psd']))\n",
"\n",
"                # plot postevent FD data\n",
"                ax.loglog(fd_data['freqs'], np.abs(fd_data['datafd'])**2, color='lightgray', alpha=0.4, label='Postevent Data')\n",
"\n",
"                # Plot GWTC PSD\n",
"                ax.loglog(psd_data['freqs'], psd_data['psd'],\n",
"                          color=color, linewidth=1.5,\n",
"                          label=f'GWTC PSD (pval={pval_gwtc:.3f})')\n",
"\n",
"                # Plot Welch PSD\n",
"                ax.loglog(welch_data['freqs'], welch_data['psd'],\n",
"                          color=color, linewidth=2, linestyle='--', alpha=0.3,\n",
"                          label=f'Welch PSD (pval={pval_welch:.3f})')\n",
"\n",
"                ax.set_ylabel(\"PSD [strain²/Hz]\", fontsize=12)\n",
"                ax.set_title(f\"{detector} Power Spectral Density\", fontsize=12)\n",
"                ax.legend(fontsize=10)\n",
"                ax.grid(True, alpha=0.3)\n",
"\n",
"                # y-limits from finite, strictly positive values only: a zero,\n",
"                # NaN or inf would give limits a log axis cannot use\n",
"                psd_values = np.asarray(psd_data['psd'])\n",
"                usable = psd_values[np.isfinite(psd_values) & (psd_values > 0)]\n",
"                if usable.size > 0:\n",
"                    ax.set_ylim(np.min(usable) * 0.1, np.max(usable) * 10)\n",
"\n",
"            # Set x-axis label only on bottom subplot\n",
"            axes[-1].set_xlabel(\"Frequency [Hz]\", fontsize=12)\n",
"\n",
"            # Overall title\n",
"            fig.suptitle(f\"{self.event_name} Power Spectral Densities\", fontsize=14, y=0.98)\n",
"            fig.tight_layout()\n",
"\n",
"            # Save plot\n",
"            os.makedirs(output_dir, exist_ok=True)\n",
"            plot_path = os.path.join(output_dir, f\"{self.event_name}_psd.png\")\n",
"            fig.savefig(plot_path, dpi=150, bbox_inches='tight')\n",
"            print(f\"PSD plot saved to: {plot_path}\")\n",
"        finally:\n",
"            plt.close(fig)\n",
"\n",
"\n",
"def process_and_save_event(event_name: str, output_dir: str = \"event_data\") -> str:\n",
"    \"\"\"Load one event, write its HDF5 file and plots, and report the outcome.\n",
"\n",
"    Returns the HDF5 path on success. Any exception raised along the way is\n",
"    printed and turned into a ``None`` return, so a batch caller can carry on.\n",
"    \"\"\"\n",
"    try:\n",
"        event = GWEventData.from_ozstar(event_name)\n",
"        saved_path = event.to_hdf5(output_dir)\n",
"        event.plot_psds(output_dir)\n",
"\n",
"        # Strain plots are optional: they need loaded strain and a plot_strain method\n",
"        if event.strain_data and hasattr(event, 'plot_strain'):\n",
"            event.plot_strain('analysis', output_dir, time_window=10.0)  # 10s window\n",
"            # NOTE(review): the loader code earlier in this class reads\n",
"            # strain_data['postevent']; confirm 'psd' is still a data type\n",
"            # that plot_strain accepts.\n",
"            event.plot_strain('psd', output_dir)\n",
"\n",
"        print(\n",
"            f\"Successfully processed {event_name}\",\n",
"            f\"Data saved to: {saved_path}\",\n",
"            \"-\" * 50,\n",
"            sep=\"\\n\",\n",
"        )\n",
"        return saved_path\n",
"    except Exception as e:\n",
"        print(f\"Error processing {event_name}: {e}\")\n",
"        return None\n",
"\n",
"\n",
"def process_all_events(output_dir: str = \"event_data\"):\n",
"    \"\"\"Run ``process_and_save_event`` on every event listed by ``get_pe_paths``.\n",
"\n",
"    Returns:\n",
"        ``(successful, failed)``: two lists of event names, split by whether\n",
"        processing handed back a file path.\n",
"    \"\"\"\n",
"    pe_paths = get_pe_paths()\n",
"\n",
"    print(f\"Found {len(pe_paths)} events to process\")\n",
"    print(f\"Output directory: {output_dir}\")\n",
"    print(\"=\" * 50)\n",
"\n",
"    successful, failed = [], []\n",
"    for event_name in tqdm(pe_paths.keys(), desc=\"Processing events\"):\n",
"        saved_path = process_and_save_event(event_name, output_dir)\n",
"        # A falsy result (None) marks the event as failed\n",
"        bucket = successful if saved_path else failed\n",
"        bucket.append(event_name)\n",
"\n",
"    print(f\"\\nProcessing complete:\")\n",
"    print(f\"Successful: {len(successful)} events\")\n",
"    print(f\"Failed: {len(failed)} events\")\n",
"    if failed:\n",
"        print(f\"Failed events: {failed}\")\n",
"\n",
"    return successful, failed\n",
"\n",
"\n",
"def main():\n",
"    \"\"\"Demo driver: summarise the inputs, round-trip the first event, then process all.\"\"\"\n",
"    pe_paths = get_pe_paths()\n",
"    data_paths = get_data_dicts()\n",
"\n",
"    print(f\"Found {len(pe_paths)} PE data files\")\n",
"    for ifo in ('L1', 'H1'):\n",
"        print(f\"Found {len(data_paths[ifo])} {ifo} strain files\")\n",
"\n",
"    # Examples run on the first listed event, if there is one\n",
"    if pe_paths:\n",
"        first_event = next(iter(pe_paths.keys()))\n",
"        print(f\"\\n=== Example 1: Processing single event with strain data ===\")\n",
"        print(f\"Processing: {first_event}\")\n",
"\n",
"        try:\n",
"            # Example 1: load, save to HDF5, plot\n",
"            event_data = GWEventData.from_ozstar(first_event)\n",
"            filepath = event_data.to_hdf5(\"example_output\")\n",
"            print(f\"Saved to: {filepath}\")\n",
"            event_data.plot_psds(\"example_output\")\n",
"\n",
"            # Example 2: read the file straight back\n",
"            print(f\"\\n=== Example 2: Loading from HDF5 ===\")\n",
"            loaded_data = GWEventData.from_hdf5(filepath)\n",
"            print(f\"Loaded: {loaded_data}\")\n",
"        except Exception as e:\n",
"            # Report the failure with a traceback, then still run the batch below\n",
"            print(f\"Error in examples: {e}\")\n",
"            import traceback\n",
"            traceback.print_exc()\n",
"\n",
"    print(f\"\\n=== Processing all events ===\")\n",
"    process_all_events()\n",
"\n",
"\n",
"# Run main() when this code is executed as the main module\n",
"if __name__ == \"__main__\":\n",
"    main()"
],
"metadata": {
"id": "wjDTw4j1xQ4j"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "DPDWAjzXojb7"
},
"execution_count": null,
"outputs": []
}
]
}
@avivajpeyi
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment