Skip to content

Instantly share code, notes, and snippets.

@juliuskittler
Created November 3, 2022 12:14
Show Gist options
  • Select an option

  • Save juliuskittler/c1cfc61f348180d0207549433bb2f162 to your computer and use it in GitHub Desktop.

Select an option

Save juliuskittler/c1cfc61f348180d0207549433bb2f162 to your computer and use it in GitHub Desktop.
"""
This is an adjusted version of the following file: https://github.com/outerbounds/dsbook/blob/main/chapter-3/classifier_train.py
"""
from metaflow import FlowSpec, step, project, conda_base
@project(name="dummy_example")
@conda_base(
    python="3.8.13",
    # NOTE(review): pandas is pinned here but no step in this flow imports it.
    libraries={"pandas": "1.5", "scikit-learn": "1.1.2"},
)
class ClassifierTrainFlow(FlowSpec):
    """Train a KNN and an SVM on the wine dataset and keep the better one.

    Step graph: start -> (train_knn, train_svm) -> choose_model -> end.
    """

    @step
    def start(self):
        """Load the wine dataset and hold out 20% of it as a test split."""
        from sklearn.datasets import load_wine
        from sklearn.model_selection import train_test_split

        features, labels = load_wine(return_X_y=True)
        # random_state=0 fixes the shuffle so every run sees the same split.
        (
            self.train_data,
            self.test_data,
            self.train_labels,
            self.test_labels,
        ) = train_test_split(features, labels, test_size=0.2, random_state=0)
        # Fan out into one branch per classifier; both rejoin in choose_model.
        self.next(self.train_knn, self.train_svm)

    @step
    def train_knn(self):
        """Fit a k-nearest-neighbours classifier with default settings."""
        from sklearn.neighbors import KNeighborsClassifier

        knn = KNeighborsClassifier()
        knn.fit(self.train_data, self.train_labels)
        self.model = knn
        self.next(self.choose_model)

    @step
    def train_svm(self):
        """Fit a support-vector classifier with a polynomial kernel."""
        from sklearn.svm import SVC

        svc = SVC(kernel="poly")
        svc.fit(self.train_data, self.train_labels)
        self.model = svc
        self.next(self.choose_model)

    @step
    def choose_model(self, inputs):
        """Join the branches, rank their models by test score, keep the best.

        ``score`` is the estimator's own scoring method (mean accuracy for
        scikit-learn classifiers), evaluated on the held-out split.
        """
        scored = [
            (branch.model, branch.model.score(branch.test_data, branch.test_labels))
            for branch in inputs
        ]
        # Negating the score makes the ascending (stable) sort put the best first.
        self.results = sorted(scored, key=lambda pair: -pair[1])
        self.model = self.results[0][0]
        self.next(self.end)

    @step
    def end(self):
        """Print each model with its score, best first."""
        print('Scores:')
        report = '\n'.join('%s %f' % (model, score) for model, score in self.results)
        print(report)
if __name__ == "__main__":
    # Entry point: constructing the flow object hands control to Metaflow's
    # command-line interface (e.g. `python <this file> run`).
    ClassifierTrainFlow()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment