This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import pandas as pd | |
class DecisionTree:
    """Node of a binary decision tree.

    Each node stores a depth limit, its own depth, and references to a left
    and a right child node (both ``None`` until children are attached).
    """

    def __init__(self, max_depth=6, depth=1):
        """Create a node without children.

        Args:
            max_depth: Depth limit stored on the node.
            depth: Depth of this node within the tree.
        """
        self.max_depth = max_depth
        self.depth = depth
        # A freshly created node has no children yet.
        self.left = None
        self.right = None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict(self, data):
    """Return one prediction per row of ``data`` as a NumPy array.

    Args:
        data: DataFrame whose rows are passed, one at a time and as
            label-indexed Series (via ``iterrows``), to the tree traversal.

    Returns:
        numpy.ndarray built from the per-row traversal results.
    """
    return np.array([self.__flow_data_thru_tree(row) for _, row in data.iterrows()])
def __flow_data_thru_tree(self, row):
    """Walk ``row`` down the tree and return the reached leaf's ``probability``.

    At every non-leaf node the row goes to the left child when
    ``row[split_feature] <= criteria`` and to the right child otherwise.

    Args:
        row: Mapping-like record indexable by the nodes' ``split_feature``.

    Returns:
        The ``probability`` attribute of the leaf node the row ends up in.
    """
    # Iterative descent: same path as the recursive formulation, without
    # depending on the interpreter's recursion limit.
    node = self
    while not node.is_leaf_node:
        if row[node.split_feature] <= node.criteria:
            node = node.left
        else:
            node = node.right
    return node.probability
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@property
def is_leaf_node(self):
    """``True`` when this node has no left child (``self.left is None``)."""
    return self.left is None
@property
def probability(self):
    """Proportion of each target value among this node's rows.

    Returns:
        list[float]: one entry per distinct value of ``self.target`` found in
        ``self.data``, ordered by ascending target value. Each entry is that
        value's row count divided by ``len(self.data)``.
    """
    # value_counts() orders by frequency, which would put the majority class
    # first and make positions mean different classes in different nodes.
    # Sorting by label keeps the positions comparable between nodes.
    # NOTE(review): a target value absent from this node's rows still gets
    # no entry, so the list can be shorter than the number of classes.
    counts = self.data[self.target].value_counts().sort_index()
    return (counts / len(self.data)).tolist()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __init__(self, max_depth=4, depth=1):
    """Create a node without children.

    Args:
        max_depth: Depth limit stored on the node.
        depth: Depth of this node within the tree.
    """
    self.max_depth = max_depth
    self.depth = depth
    # A freshly created node has no children yet.
    self.left = None
    self.right = None
| def __create_branches(self): | |
| self.left = DecisionTree(max_depth = self.max_depth, | |
| depth = self.depth + 1) | |
| self.right = DecisionTree(max_depth = self.max_depth, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __create_branches(self):
    """Create both child nodes and fit each on its share of ``self.data``.

    Rows with ``split_feature <= criteria`` go to the left child; rows with
    ``split_feature > criteria`` go to the right child. Both children are
    fitted against the same target column as this node.
    """
    # Children are built with the constructor's default arguments.
    self.left = DecisionTree()
    self.right = DecisionTree()
    column = self.data[self.split_feature]
    left_rows = self.data[column <= self.criteria]
    right_rows = self.data[column > self.criteria]
    self.left.fit(data=left_rows, target=self.target)
    self.right.fit(data=right_rows, target=self.target)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __find_best_split(self):
    """Pick the column and threshold with the highest information gain.

    Every column in ``self.independent`` is evaluated with
    ``__find_best_split_for_column``; columns for which that helper reports
    no split (``split is None``) are ignored. On equal gain the column seen
    first is kept.

    Returns:
        tuple: ``(split, col)`` for the best candidate.

    Raises:
        KeyError: if no column produced a usable split.
    """
    best_split = {}
    for col in self.independent:
        information_gain, split = self.__find_best_split_for_column(col)
        if split is None:
            continue
        # Strict comparison: a later column must beat the current best.
        if not best_split or best_split["information_gain"] < information_gain:
            best_split = {
                "split": split,
                "col": col,
                "information_gain": information_gain,
            }
    return best_split["split"], best_split["col"]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def __find_best_split_for_column(self, col): | |
| x = self.data[col] | |
| unique_values = x.unique() | |
| if len(unique_values) == 1: return None, None | |
| information_gain = None | |
| split = None | |
| for val in unique_values: | |
| left = x <= val | |
| right = x > val | |
| left_data = self.data[left] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __calculate_impurity_score(self, data):
    """Return the Gini impurity of the labels in ``data``.

    Args:
        data: pandas Series of labels, or ``None``.

    Returns:
        ``0`` for ``None`` or empty input; otherwise the sum of
        ``p * (1 - p)`` over the proportion ``p`` of each distinct label,
        which equals ``2 * p * (1 - p)`` when there are exactly two labels.
    """
    if data is None or data.empty:
        return 0
    proportions = (data.value_counts() / len(data)).tolist()
    if len(proportions) == 2:
        # Two-class case written in its closed form so results stay
        # bit-identical to the original two-class formula.
        p_i = proportions[0]
        return p_i * (1 - p_i) * 2
    # A pure partition (one label) scores 0; more than two labels use the
    # general Gini sum instead of failing on tuple unpacking.
    return sum(p * (1 - p) for p in proportions)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __init__(self):
    """Create a node whose left and right children are both unset."""
    self.left = None
    self.right = None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DecisionTree:
    """Decision tree fitted on a pandas DataFrame with a named target column."""

    def fit(self, data, target):
        """Store the training frame and work out the feature columns.

        Args:
            data: DataFrame holding the feature columns and the target column.
            target: Label of the target column in ``data``.

        Raises:
            ValueError: if ``target`` is not one of ``data``'s columns.
        """
        self.data = data
        self.target = target
        # Every column except the target is treated as a feature.
        self.independent = self.data.columns.tolist()
        self.independent.remove(target)

    def predict(self, data):
        """Return one prediction per row of ``data`` as a NumPy array.

        Rows are handed to the traversal as label-indexed Series so that it
        can look values up by column name.

        NOTE(review): relies on a ``__flow_data_thru_tree`` method that is
        not defined in this class body as shown.
        """
        return np.array([self.__flow_data_thru_tree(row) for _, row in data.iterrows()])
NewerOlder