Last active
November 8, 2019 13:07
-
-
Save lucasc896/eacd91069e85d85dfb8bc25ecdeaa7e4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from numpy import array as np_array | |
| from sklearn.metrics import confusion_matrix | |
| from torch import Tensor as pytorch_tensor | |
def ensure_array(obj):
    """Coerce *obj* to a NumPy array.

    PyTorch tensors are converted via ``.detach().cpu().numpy()`` so that
    tensors which require grad (``.numpy()`` raises on those) and tensors
    living on a CUDA device are handled correctly; for a plain CPU tensor
    this is identical to ``.numpy()``. Anything else goes through
    ``numpy.array``.
    """
    if isinstance(obj, pytorch_tensor):
        # .detach() drops the autograd graph, .cpu() moves off-device;
        # both are no-ops for a detached CPU tensor.
        return obj.detach().cpu().numpy()
    return np_array(obj)
class Confusion:
    """Per-class statistics derived from a confusion matrix.

    Wraps ``sklearn.metrics.confusion_matrix`` and exposes, for any class
    label, the condition positives/negatives (``p``/``n``), the four
    confusion counts (``tp``/``fp``/``tn``/``fn``), and the derived rates
    ``tpr`` and ``fnr``. Rows of the matrix are true labels, columns are
    predictions.
    """

    def __init__(self, *, ytrue, ypred, classes):
        self.ytrue = ensure_array(ytrue)
        self.ypred = ensure_array(ypred)
        self.classes = ensure_array(classes)
        # Map each class label to its row/column index in the matrix.
        self.class_to_idx = {
            label: idx for idx, label in enumerate(self.classes)
        }
        self.total = self.ytrue.size
        self.confusion_matrix = confusion_matrix(
            y_true=self.ytrue, y_pred=self.ypred, labels=self.classes
        )

    def p(self, class_label):
        """Condition positives: number of samples truly of this class."""
        row = self.confusion_matrix[self.class_to_idx[class_label], :]
        return row.sum()

    def n(self, class_label):
        """Condition negatives: samples truly of any other class."""
        return self.total - self.p(class_label)

    def tp(self, class_label):
        """True positives: diagonal entry for this class."""
        idx = self.class_to_idx[class_label]
        return self.confusion_matrix[idx, idx]

    def fp(self, class_label):
        """False positives: predicted this class but truly another."""
        idx = self.class_to_idx[class_label]
        col_total = self.confusion_matrix[:, idx].sum()
        return col_total - self.tp(class_label)

    def tn(self, class_label):
        """True negatives: neither truly nor predicted this class."""
        idx = self.class_to_idx[class_label]
        row_total = self.confusion_matrix[idx, :].sum()
        col_total = self.confusion_matrix[:, idx].sum()
        # Inclusion-exclusion: everything minus this row and column,
        # adding back the diagonal cell that was subtracted twice.
        return (
            self.confusion_matrix.sum() - row_total - col_total
            + self.tp(class_label)
        )

    def fn(self, class_label):
        """False negatives: truly this class but predicted another."""
        idx = self.class_to_idx[class_label]
        row_total = self.confusion_matrix[idx, :].sum()
        return row_total - self.tp(class_label)

    def tpr(self, class_label):
        """True-positive rate (recall): tp / condition positives."""
        return self.tp(class_label) / self.p(class_label)

    def fnr(self, class_label):
        """False-negative rate: the complement of the TPR."""
        return 1.0 - self.tpr(class_label)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment