Created March 20, 2019 01:23
-
-
Save kevincdurand1/fc1a446193a169ea846188e8453952a2 to your computer and use it in GitHub Desktop.
series_to_supervised.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pandas import DataFrame | |
| from pandas import concat | |
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):
    """Frame a time series as a supervised learning dataset.

    Each output row holds ``n_in`` lagged copies of every variable followed
    by ``n_out`` current/future copies, built with ``DataFrame.shift``.

    Arguments:
        data: Sequence of observations. A flat list, 1-D NumPy array or
            pandas Series is treated as a single variable; a 2-D array,
            list of rows or DataFrame contributes one variable per column.
        n_in: Number of lag observations as input (X), i.e. the columns for
            t-n_in ... t-1.
        n_out: Number of observations as output (y), i.e. the columns for
            t ... t+n_out-1.
        dropnan: Boolean whether or not to drop rows with NaN values. Rows
            made incomplete by shifting (the first n_in rows and the last
            n_out-1 rows) contain NaN, as does any NaN already in ``data``.

    Returns:
        Pandas DataFrame of series framed for supervised learning. Columns
        are named ``var<j>(t-<i>)``, ``var<j>(t)`` and ``var<j>(t+<i>)``,
        with ``j`` starting at 1. The row index is that of
        ``DataFrame(data)``.

    Raises:
        ValueError: If ``n_in`` and ``n_out`` together select no columns
            (e.g. both are 0).
    """
    df = DataFrame(data)
    # Count variables on the constructed frame rather than on the raw input:
    # 1-D arrays and Series have no shape[1], and a list of rows has more
    # than one column even though it is a list.
    n_vars = df.shape[1]
    cols, names = [], []
    # input sequence (t-n, ... t-1)
    for i in range(n_in, 0, -1):
        cols.append(df.shift(i))
        names += ['var%d(t-%d)' % (j + 1, i) for j in range(n_vars)]
    # forecast sequence (t, t+1, ... t+n)
    for i in range(n_out):
        cols.append(df.shift(-i))
        step = 't' if i == 0 else 't+%d' % i
        names += ['var%d(%s)' % (j + 1, step) for j in range(n_vars)]
    if not cols:
        # Without this, pandas.concat fails with an opaque "No objects to
        # concatenate" error.
        raise ValueError('n_in and n_out must select at least one column')
    # put it all together
    agg = concat(cols, axis=1)
    agg.columns = names
    # drop rows with NaN values
    if dropnan:
        agg = agg.dropna()
    return agg
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment