@dagrha
Created November 8, 2020 22:53
Run a simple linear regression and return a dictionary with the r2, the fitted equation, and a series to plot, for easy use with matplotlib
# © 2020 dagrha. GPLv3.0
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


def linear_regression(df, x_col, y_col):
    # Keep only rows where both columns are non-null
    mask = df[x_col].notnull() & df[y_col].notnull()
    x = df.loc[mask, x_col].values.reshape(-1, 1)
    y = df.loc[mask, y_col].values.reshape(-1, 1)
    lr = LinearRegression()
    lr.fit(x, y)
    # Extend the plotted line beyond the data range (assumes positive x values)
    x_min_buffer = x.min() * 0.5
    x_max_buffer = x.max() * 1.5
    x_out = np.array([x_min_buffer, x_max_buffer]).reshape(-1, 1)
    y_out = lr.predict(x_out)
    # Pre-formatted strings, ready to use as plot labels or annotations
    r2 = f'r2: {r2_score(y, lr.predict(x)):0.3f}'
    eq = f'{y_col} = {lr.coef_[0][0]:.5E} * {x_col} + {lr.intercept_[0]:.5E}'
    return {
        'x': x_out,
        'y': y_out,
        'r2': r2,
        'eq': eq,
    }
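
For reference, the slope, intercept, and r2 reported by the function should agree with the closed-form ordinary least squares formulas. The sketch below is a pure-Python cross-check and is not part of the gist; `ols_fit` and the sample data are hypothetical names chosen for illustration, and the output strings only mimic the format used above.

```python
def ols_fit(xs, ys):
    """Return (slope, intercept, r2) for the fit y = slope * x + intercept."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # Sum of squared deviations of x, and the x-y cross deviations
    sxx = sum((x - x_mean) ** 2 for x in xs)
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = y_mean - slope * x_mean
    # r2 = 1 - (residual sum of squares) / (total sum of squares)
    ss_res = sum((y - (slope * x + intercept)) ** 2 for x, y in zip(xs, ys))
    ss_tot = sum((y - y_mean) ** 2 for y in ys)
    r2 = 1 - ss_res / ss_tot
    return slope, intercept, r2


# Hypothetical sample data lying exactly on y = 2x + 1
xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 5.0, 7.0, 9.0]
slope, intercept, r2 = ols_fit(xs, ys)
print(f'y = {slope:.5E} * x + {intercept:.5E}')
print(f'r2: {r2:0.3f}')
```

Feeding the same columns through `linear_regression` should give matching numbers in the `eq` and `r2` strings, up to floating-point rounding.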