Created
May 11, 2023 02:44
-
-
Save thuliumsystems/db562b9fdb2efbbd55d3ae765059704d to your computer and use it in GitHub Desktop.
GridSearchCV
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.ensemble import ( | |
| RandomForestClassifier, | |
| AdaBoostClassifier, | |
| GradientBoostingClassifier, | |
| ) | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.naive_bayes import GaussianNB | |
| from sklearn.neighbors import KNeighborsClassifier | |
| from sklearn.neural_network import MLPClassifier | |
| from sklearn.svm import SVC | |
| from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis | |
| from sklearn.model_selection import GridSearchCV | |
# Labelled samples for the grid searches below. Each row is 11 integers:
# 10 feature values followed by a 0/1 class label in the last position
# (the code below uses the last element as y and the rest as X).
# NOTE(review): several rows are exact duplicates (e.g.
# [10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1] appears four times). Duplicates can be
# split across CV train/validation folds and inflate the reported scores —
# confirm whether they are intentional.
training = [
    [10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
    [21, 16, 12, 26, 6, 0, 0, 0, 0, 0, 1],
    [2, 5, 2, 3, 3, 3, 2, 3, 2, 4, 1],
    [13, 6, 3, 4, 4, 1, 1, 0, 0, 1, 1],
    [2, 3, 4, 2, 5, 2, 1, 3, 2, 1, 1],
    [2, 3, 2, 3, 4, 2, 4, 3, 2, 2, 1],
    [1, 0, 1, 1, 1, 4, 33, 2, 9, 4, 0],
    [6, 4, 3, 4, 2, 6, 2, 4, 1, 4, 1],
    [29, 12, 3, 6, 6, 2, 1, 0, 0, 1, 1],
    [0, 0, 0, 0, 0, 10, 27, 8, 8, 13, 0],
    [21, 5, 11, 12, 9, 0, 0, 0, 0, 0, 1],
    [5, 2, 2, 2, 2, 4, 5, 3, 4, 2, 1],
    [14, 20, 6, 8, 8, 0, 0, 0, 0, 0, 1],
    [2, 3, 2, 1, 1, 3, 7, 9, 3, 3, 0],
    [5, 2, 5, 2, 4, 2, 3, 3, 2, 1, 1],
    [8, 6, 4, 3, 2, 3, 4, 2, 1, 3, 1],
    [5, 2, 0, 0, 1, 7, 5, 9, 5, 2, 0],
    [6, 19, 2, 3, 4, 2, 0, 1, 1, 0, 1],
    [0, 0, 0, 0, 0, 6, 16, 6, 19, 22, 0],
    [17, 5, 28, 23, 4, 0, 0, 0, 0, 0, 1],
    [4, 15, 2, 7, 11, 0, 0, 0, 0, 0, 1],
    [3, 5, 4, 3, 8, 1, 2, 1, 2, 3, 1],
    [1, 1, 0, 1, 0, 7, 7, 5, 28, 6, 0],
    [5, 5, 3, 3, 8, 1, 1, 2, 1, 1, 1],
    [2, 1, 1, 1, 1, 9, 7, 8, 2, 6, 0],
    [7, 5, 5, 2, 4, 2, 2, 1, 1, 1, 1],
    [3, 23, 3, 5, 6, 2, 1, 2, 2, 1, 1],
    [1, 1, 1, 0, 1, 8, 5, 3, 2, 6, 0],
    [11, 9, 6, 1, 3, 0, 2, 1, 1, 2, 1],
    [11, 3, 8, 2, 5, 1, 1, 2, 1, 1, 1],
    [8, 11, 8, 5, 10, 1, 0, 0, 1, 0, 1],
    [16, 19, 6, 8, 2, 0, 0, 0, 0, 1, 1],
    [7, 3, 5, 3, 5, 5, 4, 1, 1, 2, 1],
    [1, 1, 1, 0, 1, 24, 5, 5, 5, 4, 0],
    [1, 0, 1, 0, 1, 12, 8, 10, 7, 4, 0],
    [6, 16, 5, 16, 7, 1, 0, 1, 0, 0, 1],
    [4, 4, 2, 2, 3, 5, 7, 1, 1, 2, 0],
    [4, 8, 3, 3, 5, 2, 1, 1, 1, 3, 1],
    [9, 10, 5, 7, 5, 1, 1, 1, 0, 0, 1],
    [4, 1, 1, 1, 2, 22, 4, 19, 4, 2, 0],
    [0, 0, 0, 0, 1, 13, 9, 10, 4, 11, 0],
    [14, 14, 9, 4, 4, 1, 1, 0, 0, 0, 1],
    [21, 24, 4, 8, 5, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 13, 6, 6, 34, 4, 0],
    [7, 5, 3, 3, 1, 9, 2, 1, 3, 4, 1],
    [0, 0, 0, 0, 0, 6, 10, 4, 13, 30, 0],
    [5, 8, 10, 9, 3, 1, 0, 1, 0, 1, 1],
    [7, 6, 5, 4, 2, 1, 3, 2, 1, 1, 1],
    [1, 1, 1, 1, 1, 5, 5, 7, 5, 5, 0],
    [1, 1, 2, 1, 2, 7, 4, 4, 3, 2, 0],
    [1, 0, 0, 0, 0, 31, 11, 6, 5, 7, 0],
    [1, 1, 3, 1, 1, 2, 4, 10, 4, 3, 0],
    [1, 2, 1, 1, 3, 5, 8, 9, 3, 7, 0],
    [7, 4, 4, 7, 3, 1, 2, 3, 2, 1, 1],
    [0, 1, 0, 1, 0, 12, 7, 22, 4, 16, 0],
    [1, 4, 3, 1, 1, 4, 5, 11, 2, 7, 0],
    [3, 2, 2, 3, 9, 7, 8, 1, 1, 2, 0],
    [3, 3, 1, 2, 2, 4, 7, 4, 3, 5, 0],
    [6, 7, 29, 3, 6, 0, 1, 0, 1, 1, 1],
    [0, 0, 1, 0, 0, 3, 10, 12, 10, 10, 0],
    [1, 1, 1, 2, 1, 8, 4, 5, 5, 4, 0],
    [0, 1, 0, 1, 0, 4, 22, 7, 4, 3, 0],
    [5, 3, 4, 5, 3, 3, 1, 2, 2, 1, 1],
    [1, 0, 1, 1, 0, 6, 8, 13, 6, 4, 0],
    [2, 1, 1, 0, 0, 18, 6, 38, 5, 8, 0],
    [1, 0, 3, 2, 2, 5, 8, 4, 3, 4, 0],
    [5, 18, 6, 5, 5, 0, 1, 0, 0, 1, 1],
    [0, 1, 2, 1, 3, 4, 6, 12, 2, 13, 0],
    [0, 1, 1, 1, 0, 7, 15, 6, 4, 7, 0],
    [0, 0, 1, 1, 1, 5, 18, 7, 3, 6, 0],
    [6, 7, 5, 9, 4, 1, 1, 0, 1, 1, 1],
    [0, 3, 0, 0, 0, 8, 6, 7, 19, 2, 0],
    [0, 1, 0, 0, 0, 7, 13, 11, 13, 6, 0],
    [0, 0, 1, 0, 0, 7, 5, 19, 8, 10, 0],
    [0, 1, 2, 0, 0, 19, 28, 10, 7, 5, 0],
    [17, 10, 10, 9, 5, 1, 0, 0, 0, 0, 1],
    [20, 9, 7, 8, 14, 0, 0, 0, 0, 0, 1],
    [0, 2, 0, 1, 0, 14, 4, 40, 4, 2, 0],
    [0, 0, 0, 0, 0, 5, 10, 6, 11, 11, 0],
    [2, 3, 3, 5, 3, 4, 5, 3, 2, 1, 0],
    [17, 3, 3, 4, 12, 2, 1, 0, 1, 1, 1],
    [7, 3, 6, 5, 4, 5, 2, 1, 1, 0, 1],
    [3, 3, 1, 1, 1, 4, 1, 6, 2, 3, 0],
    [4, 5, 3, 2, 4, 1, 4, 1, 1, 0, 1],
    [5, 7, 5, 2, 10, 1, 1, 1, 1, 2, 1],
    [1, 1, 0, 0, 1, 7, 5, 13, 10, 6, 0],
    [1, 0, 0, 1, 0, 21, 8, 6, 3, 9, 0],
    [8, 3, 1, 3, 5, 2, 3, 0, 2, 2, 1],
    [4, 1, 0, 1, 2, 20, 6, 2, 5, 3, 0],
    [1, 1, 1, 0, 0, 4, 11, 6, 4, 9, 0],
    [9, 3, 8, 2, 2, 2, 3, 0, 1, 1, 1],
    [7, 5, 6, 4, 2, 4, 3, 1, 1, 1, 1],
    [10, 6, 4, 6, 7, 0, 2, 1, 0, 1, 1],
    [4, 6, 5, 8, 8, 1, 0, 1, 1, 0, 1],
    [3, 4, 4, 3, 7, 2, 3, 1, 1, 2, 1],
    [6, 9, 10, 4, 8, 1, 1, 1, 2, 1, 1],
    [5, 9, 4, 3, 3, 1, 1, 0, 0, 1, 1],
    [9, 1, 2, 2, 2, 5, 5, 3, 3, 4, 0],
    [0, 0, 2, 1, 0, 9, 6, 6, 9, 7, 0],
    [14, 11, 7, 4, 2, 2, 1, 1, 1, 1, 1],
    [10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
    [7, 6, 7, 4, 22, 0, 1, 0, 0, 1, 1],
    [10, 2, 8, 2, 2, 2, 3, 1, 4, 2, 1],
    [11, 32, 15, 8, 5, 0, 0, 0, 0, 0, 1],
    [1, 1, 1, 1, 0, 3, 14, 7, 6, 5, 0],
    [6, 19, 8, 19, 16, 0, 0, 0, 0, 0, 1],
    [2, 3, 1, 0, 1, 13, 6, 5, 3, 2, 0],
    [17, 17, 36, 3, 3, 0, 0, 0, 1, 1, 1],
    [1, 3, 1, 2, 1, 16, 4, 4, 1, 5, 0],
    [9, 3, 21, 7, 3, 1, 0, 1, 1, 0, 1],
    [17, 18, 3, 19, 5, 1, 0, 0, 0, 0, 1],
    [1, 3, 1, 1, 2, 11, 3, 4, 3, 4, 0],
    [6, 3, 6, 3, 2, 1, 2, 3, 2, 3, 1],
    [2, 2, 2, 2, 2, 4, 3, 5, 3, 6, 0],
    [6, 14, 11, 4, 17, 0, 0, 1, 0, 0, 1],
    [10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
    [25, 5, 3, 3, 3, 1, 1, 0, 1, 2, 1],
    [6, 31, 6, 5, 8, 0, 1, 0, 0, 0, 1],
    [4, 3, 3, 2, 1, 3, 4, 2, 2, 4, 1],
    [1, 1, 1, 1, 0, 5, 5, 13, 13, 10, 0],
    [18, 9, 18, 5, 19, 0, 0, 0, 0, 0, 1],
    [0, 1, 2, 0, 0, 4, 14, 5, 3, 6, 0],
    [5, 6, 2, 3, 2, 2, 2, 2, 2, 1, 1],
    [23, 5, 2, 3, 4, 4, 2, 1, 2, 1, 1],
    [7, 5, 8, 8, 7, 2, 1, 1, 0, 0, 1],
    [1, 3, 1, 2, 2, 10, 7, 4, 3, 4, 0],
    [3, 3, 2, 2, 2, 2, 6, 4, 2, 3, 1],
    [10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
    [15, 6, 11, 7, 5, 0, 1, 0, 1, 0, 1],
    [11, 3, 8, 2, 5, 1, 1, 2, 1, 1, 1],
    [11, 5, 11, 3, 3, 1, 2, 1, 1, 1, 1],
    [8, 29, 6, 9, 7, 1, 0, 0, 0, 1, 1],
    [6, 16, 5, 16, 7, 1, 0, 1, 0, 0, 1],
    [3, 2, 8, 2, 1, 3, 3, 3, 1, 1, 1],
    [21, 24, 4, 8, 5, 0, 0, 0, 0, 0, 1],
    [0, 1, 1, 1, 1, 6, 8, 5, 17, 4, 0],
    [19, 16, 12, 4, 4, 0, 0, 1, 1, 0, 1],
    [3, 1, 2, 1, 1, 10, 7, 8, 3, 10, 0],
    [1, 0, 0, 0, 0, 31, 11, 6, 5, 7, 0],
    [16, 7, 18, 8, 2, 0, 1, 1, 0, 0, 1],
    [0, 1, 0, 0, 0, 7, 13, 11, 13, 6, 0],
    [0, 0, 1, 0, 0, 7, 5, 19, 8, 10, 0],
    [17, 10, 10, 9, 5, 1, 0, 0, 0, 0, 1],
    [2, 1, 0, 1, 1, 14, 10, 6, 5, 5, 0],
    [12, 7, 4, 5, 6, 0, 1, 1, 0, 0, 1],
    [1, 1, 1, 0, 0, 8, 13, 9, 9, 6, 0],
    [1, 2, 1, 1, 1, 18, 7, 9, 3, 1, 0],
    [7, 2, 3, 4, 2, 3, 5, 3, 3, 2, 1],
    [2, 3, 3, 5, 3, 4, 5, 3, 2, 1, 0],
    [10, 5, 7, 3, 3, 1, 2, 1, 1, 1, 1],
    [8, 9, 5, 5, 2, 2, 2, 3, 0, 0, 1],
    [4, 3, 10, 3, 1, 1, 3, 3, 4, 1, 1],
    [7, 3, 8, 3, 8, 1, 1, 1, 2, 1, 1],
    [1, 3, 4, 1, 3, 2, 2, 5, 2, 1, 1],
    [1, 1, 0, 0, 1, 7, 5, 13, 10, 6, 0],
    [6, 5, 4, 2, 2, 1, 4, 2, 1, 1, 1],
    [1, 3, 1, 2, 1, 16, 4, 4, 1, 5, 0],
    [4, 6, 4, 7, 3, 2, 1, 2, 2, 1, 1],
    [4, 4, 16, 5, 16, 3, 0, 1, 0, 1, 1],
    [5, 10, 14, 16, 3, 0, 0, 0, 0, 0, 1],
    [23, 5, 2, 3, 4, 4, 2, 1, 2, 1, 1],
    [2, 1, 2, 1, 3, 2, 5, 5, 3, 3, 0],
    [5, 5, 8, 4, 2, 3, 3, 1, 1, 2, 1],
    [3, 5, 3, 2, 4, 2, 3, 3, 0, 2, 1],
    [6, 19, 4, 4, 8, 0, 1, 0, 0, 0, 1],
    [1, 2, 0, 1, 1, 19, 3, 6, 9, 3, 0],
    [2, 2, 2, 0, 1, 4, 10, 9, 3, 5, 0],
    [2, 0, 0, 0, 0, 13, 6, 23, 9, 11, 0],
    [5, 34, 14, 14, 4, 0, 0, 1, 1, 0, 1],
    [19, 2, 5, 2, 21, 0, 6, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 21, 18, 20, 6, 9, 0],
    [13, 16, 22, 5, 13, 0, 0, 0, 0, 0, 1],
    [6, 1, 1, 1, 0, 5, 12, 6, 5, 2, 0],
    [5, 12, 4, 10, 4, 2, 1, 1, 1, 1, 1],
    [17, 18, 6, 8, 10, 1, 0, 0, 0, 0, 1],
    [9, 15, 10, 2, 4, 2, 1, 1, 2, 3, 1],
    [1, 0, 0, 0, 0, 16, 19, 10, 4, 6, 0],
    [6, 5, 3, 1, 3, 3, 2, 2, 1, 1, 1],
    [6, 4, 5, 8, 8, 1, 1, 0, 0, 0, 1],
    [1, 1, 1, 1, 1, 4, 5, 22, 3, 4, 0],
    [2, 2, 1, 0, 1, 7, 11, 6, 2, 3, 0],
    [6, 10, 3, 2, 5, 1, 4, 0, 1, 0, 1],
    [2, 2, 5, 2, 1, 4, 4, 3, 4, 2, 1],
    [2, 1, 0, 1, 1, 5, 12, 7, 7, 4, 0],
    [8, 22, 4, 4, 3, 1, 1, 0, 1, 1, 1],
    [0, 0, 0, 0, 0, 26, 15, 15, 7, 9, 0],
    [10, 5, 3, 3, 3, 2, 2, 3, 1, 1, 1],
    [0, 0, 0, 0, 0, 21, 9, 21, 13, 2, 0],
    [13, 4, 7, 7, 5, 1, 1, 1, 1, 0, 1],
    [12, 25, 6, 2, 6, 5, 1, 0, 0, 1, 1],
    [17, 4, 1, 1, 1, 4, 4, 8, 1, 4, 0],
    [0, 0, 0, 0, 0, 10, 22, 19, 5, 9, 0],
    [2, 5, 4, 2, 2, 9, 2, 2, 1, 1, 0],
]
# Split each training row into its feature values (X) and its trailing
# 0/1 label (y).
X = [row[:-1] for row in training]
y = [row[-1] for row in training]

# Standardize the feature columns only; fitting on the full rows would also
# standardize the label column.
# NOTE: scaled_data is not consumed by the grid searches below (they receive
# the raw X). For leakage-free cross-validation, scaling should be fit inside
# each CV fold (e.g. an sklearn Pipeline) rather than once on all data here.
scaler = StandardScaler()
scaled_data = scaler.fit_transform(X)
# Hyperparameter grids for GridSearchCV, one per estimator. Each dict maps an
# estimator parameter name to the candidate values to try; GridSearchCV fits
# every combination (for a list of dicts, every combination within each dict).
param_rf = {
    "n_estimators": [50, 100, 200],
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
}
param_dt = {
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4],
}
# LogisticRegression's default solver (lbfgs) does not support penalty="l1",
# so every l1 candidate would fail to fit; liblinear supports both l1 and l2.
param_lr = {
    "penalty": ["l1", "l2"],
    "C": [0.001, 0.01, 0.1, 1, 10],
    "solver": ["liblinear"],
}
param_ab = {"n_estimators": [50, 100, 200], "learning_rate": [0.001, 0.01, 0.1, 1, 10]}
param_gbc = {"n_estimators": [50, 100, 200], "learning_rate": [0.001, 0.01, 0.1, 1, 10]}
param_gnb = {"var_smoothing": [1e-9, 1e-8, 1e-7, 1e-6, 1e-5]}
param_knn = {"n_neighbors": [3, 5, 7, 9, 11]}
param_mlp = {
    "hidden_layer_sizes": [(10,), (50,), (100,)],
    "activation": ["tanh", "relu"],
    "alpha": [0.0001, 0.001, 0.01],
}
# gamma is ignored by the linear kernel, so it is searched only for rbf. A
# single flat grid would refit the identical linear model once per gamma value.
param_svc = [
    {"kernel": ["linear"], "C": [0.1, 1, 10, 100]},
    {"kernel": ["rbf"], "C": [0.1, 1, 10, 100], "gamma": [0.001, 0.01, 0.1, 1]},
]
param_qda = {"reg_param": [0.1, 0.5, 1.0]}
def fn_grid_search(model, params, X, y, *, cv=5, scale=False):
    """Tune ``model`` over ``params`` with a cross-validated grid search.

    Prints the best mean cross-validated score followed by the parameter
    combination that achieved it, then returns the fitted search.

    Args:
        model: An unfitted scikit-learn estimator.
        params: A parameter grid (dict of parameter name -> candidate values)
            or a list of such dicts, as accepted by ``GridSearchCV``.
        X: Feature rows.
        y: Target labels, one per row of ``X``.
        cv: Cross-validation setting forwarded to ``GridSearchCV``
            (default: 5 folds).
        scale: If True, wrap ``model`` in a ``StandardScaler`` pipeline so the
            scaler is re-fit on each training fold and never sees the
            validation fold. Parameter names in ``best_params_`` then carry a
            ``model__`` prefix.

    Returns:
        The fitted ``GridSearchCV`` instance (e.g. for ``best_estimator_``).
    """
    if scale:
        from sklearn.pipeline import Pipeline

        model = Pipeline([("scaler", StandardScaler()), ("model", model)])
        # Pipeline hyperparameters are addressed as "<step>__<param>".
        if isinstance(params, dict):
            params = {f"model__{name}": values for name, values in params.items()}
        else:
            params = [
                {f"model__{name}": values for name, values in grid.items()}
                for grid in params
            ]

    grid_search = GridSearchCV(model, params, cv=cv)
    grid_search.fit(X, y)
    print(grid_search.best_score_, "Parameters:", grid_search.best_params_)
    return grid_search
# Estimators to tune. Those whose fitting involves randomness get a fixed
# random_state so repeated runs report the same best scores and parameters.
rfc = RandomForestClassifier(random_state=0)
dtc = DecisionTreeClassifier(random_state=0)
lr = LogisticRegression(random_state=0)
ab = AdaBoostClassifier(random_state=0)
gbc = GradientBoostingClassifier(random_state=0)
gnb = GaussianNB()
knn = KNeighborsClassifier()
mlp = MLPClassifier(random_state=0)
svc = SVC()
qda = QuadraticDiscriminantAnalysis()

# One grid search per estimator; each call prints its best CV score and params.
# NOTE(review): X is the unscaled feature matrix, while KNN and SVC are
# sensitive to feature scale — consider scaling inside the CV folds.
fn_grid_search(rfc, param_rf, X, y)
fn_grid_search(dtc, param_dt, X, y)
# NOTE(review): left disabled as in the original script. A penalty="l1"
# candidate needs a solver that supports it (e.g. liblinear), otherwise
# those fits fail.
# fn_grid_search(lr, param_lr, X, y)
fn_grid_search(ab, param_ab, X, y)
fn_grid_search(gbc, param_gbc, X, y)
fn_grid_search(gnb, param_gnb, X, y)
fn_grid_search(knn, param_knn, X, y)
# NOTE(review): left disabled as in the original script; reason not recorded.
# fn_grid_search(mlp, param_mlp, X, y)
fn_grid_search(svc, param_svc, X, y)
fn_grid_search(qda, param_qda, X, y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment