@previtus
Last active August 7, 2025 22:27
Methane point source detection: Loss weighing by MF products
Weighing the loss of the model during training by the MF (Matched Filter, any variant) methane enhancement product. (Somewhat akin to the way boundaries between cells were weighted up in the original U-Net paper.) This method has been shown to work well in the STARCOP paper.
Code refs:
- in the STARCOP code we cooked these weight files as an extra product (now available as the "weight_mag1c.tif" file in https://zenodo.org/records/7863343): https://github.com/spaceml-org/STARCOP/blob/c4789268a3fa0395f92357429052f6f5fc748acb/starcop/data/feature_extration.py#L32
- in later work we compute these during training (as it's a very fast op); see the example in mf_weighing_example.py
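Before the extracted snippets below, the core idea can be sketched end-to-end on toy tensors (all names and shapes here are illustrative, not the actual STARCOP code; the MF product is assumed to be already scaled to roughly 0-1):

```python
import torch

# Toy tensors, shapes [batch, 1, H, W]
logits = torch.randn(2, 1, 8, 8)                     # raw model outputs
labels = torch.randint(0, 2, (2, 1, 8, 8)).float()   # binary plume mask
mf_product = torch.rand(2, 1, 8, 8)                  # MF enhancement, here already ~0-1

# reduction="none" keeps the loss per-pixel, with its spatial dimensions
per_pixel = torch.nn.BCEWithLogitsLoss(reduction="none")(logits, labels)

# Clip so that background pixels still contribute at least 0.1x,
# while high-enhancement (likely plume) pixels contribute up to 1x
weight = torch.clip(mf_product, 0.1, 1.0)
loss = torch.mean(per_pixel * weight)
```

The clip floor (0.1 here) keeps gradients flowing everywhere while up-weighting pixels the matched filter flags as enhanced.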
import torch
from torch.utils.data import Dataset

# Code snippets extracted from the main repo; take this as a pseudo-code suggestion and adapt it to your own code.

# IN DATA LOADER
class TileDataset(Dataset):
    # ... etc

    def __getitem__(self, idx):  # method name assumed; the snippet belongs wherever tiles are loaded
        # btw: wmf is set to 0 in no-data areas,
        # and later the loss in no-data areas can also be masked out.
        # Unless the event of interest is on the edge of the granule, there isn't much no-data anyway.
        # An example of how the wmf is computed is available in https://github.com/UNEP-IMEO-MARS/mars_mf
        # but the same principle applies to any MF variant!

        # On-the-fly weight loss calculation
        if self.settings.trainer.use_weight_loss:
            # wmf is already in a nice range, usually 0-1 (and for safety clipped to [-2, +2])
            weight_loss = torch.clip(mf_product, 0.1, 1)
            # note: in STARCOP the values were in a different range, so the formula was instead
            # torch.clip(mag1c / 400, 0.1, 1)
        # ...
# IN TRAINER
class Trainer:
    # 1st ...
    # it's important to retain the spatial resolution of the loss (BCE in this example) by setting reduction="none"
    loss_function = torch.nn.BCEWithLogitsLoss(reduction="none")

    # place the weights on the same device as the model:
    if self.settings.trainer.use_weight_loss:
        weight_loss = weight_loss.to(model.device)
    else:
        weight_loss = None

    # 2nd ...
    def model_forward(self, inputs, weight_loss=None):
        # ...
        loss = model.forward(inputs)
        print("debug print loss.shape=", loss.shape)  # for example [16, 1, 128, 128] ~ [batch size, 1, w, h]
        # side note: careful here if the model output resolution is decimated, as in some transformer
        # architectures - the weighting can be done either on the bilinearly upscaled outputs,
        # or alternatively the weight_loss would need to be downscaled to match
        # somewhere in the model forward, the loss (check that it still has its spatial dims) is multiplied this way:
        if self.settings.trainer.use_weight_loss and weight_loss is not None:
            loss = torch.mean(loss * weight_loss)
        else:
            loss = torch.mean(loss)
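The no-data masking mentioned in the comments can be folded into the same multiplication. A minimal sketch on toy tensors (the 1-for-valid mask convention and the variable names are assumptions, not the repo's code):

```python
import torch

per_pixel_loss = torch.rand(2, 1, 8, 8)   # spatial loss, e.g. BCE with reduction="none"
wmf = torch.rand(2, 1, 8, 8)              # MF-derived weight product, ~0-1
valid = torch.ones(2, 1, 8, 8)
valid[:, :, :2, :] = 0.0                  # pretend the top two rows are no-data

# weight becomes 0 in no-data areas, so those pixels drop out of the loss;
# averaging over valid pixels only avoids diluting the loss with masked zeros
weight = torch.clip(wmf, 0.1, 1.0) * valid
loss = (per_pixel_loss * weight).sum() / valid.sum().clamp(min=1.0)
```

Since the weight is exactly zero wherever `valid` is zero, the loss value is unaffected by whatever the model predicts in no-data areas.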