Last active
September 1, 2019 14:47
-
-
Save RobertTLange/d7594680fd2f35247ffd09a30c08a6f8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_logistic_regression(n, d, n_epoch, batch_size, b_init, l_rate,
                              tol=1e-5, print_every=1):
    """Fit a logistic regression by mini-batch gradient descent using forward-mode AD.

    A reference solution is fitted with scikit-learn's ``LogisticRegression``
    on the same data. Progress is measured as the Euclidean distance between
    the current coefficients and that reference, relative to the reference's
    norm.

    Args:
        n: Number of samples, forwarded to ``DataLoader``.
        d: Number of features, forwarded to ``DataLoader``.
        n_epoch: Maximum number of passes over the data.
        batch_size: Mini-batch size, forwarded to ``DataLoader``.
        b_init: Initial coefficient vector. It is copied (as float), so the
            caller's array is never modified by the in-place updates below.
        l_rate: Gradient-descent step size.
        tol: Training stops (across all epochs) once the relative parameter
            error changes by less than ``tol`` between two consecutive steps.
        print_every: Print a progress line every ``print_every`` epochs.

    Returns:
        Tuple ``(b_hist, func_val_hist, param_error, acc_hist)`` of per-step
        lists. They hold, in order:
        - independent snapshots of the coefficients;
        - the loss values (``current_dual.real``);
        - the relative parameter errors;
        - the ``acc`` values returned by ``binary_cross_entropy_dual``.
    """
    # Generate the data and set up the progress trackers.
    data_loader = DataLoader(n, d, batch_size, binary=True)
    b_hist, func_val_hist, param_error, acc_hist = [], [], [], []

    # Reference coefficients from sklearn's optimiser.
    # NOTE(review): scikit-learn >= 1.2 deprecates penalty='none' in favour of
    # penalty=None (removed in 1.4), and `multi_class` is deprecated from 1.5.
    # Kept as-is for the library version this was written against.
    logreg = LogisticRegression(penalty='none', solver='lbfgs', multi_class='multinomial')
    logreg.fit(data_loader.X, data_loader.y)
    ref_coef = logreg.coef_.ravel()
    norm_coeff = np.linalg.norm(ref_coef)

    # Copy the initial vector: `b_dual.real -= ...` below updates in place and
    # would otherwise overwrite the caller's b_init.
    b_dual = DualTensor(np.array(b_init, dtype=float), None)

    converged = False
    for epoch in range(n_epoch):
        # Reshuffle the batch identities at the start of each epoch.
        data_loader.shuffle_arrays()
        for batch_id in range(data_loader.num_batches):
            b_dual.zero_grad()
            # Select the current batch and run the "mini-forward" pass.
            X, y = data_loader.get_batch_idx(batch_id)
            y_pred_1, y_pred_2 = forward(X, b_dual)
            # Forward-mode AD result: real part = loss, dual part = derivative.
            current_dual, acc = binary_cross_entropy_dual(y, y_pred_1, y_pred_2)

            # Gradient step, then record the progress of this step.
            b_dual.real -= l_rate * np.array(current_dual.dual).flatten()
            # Store a snapshot: appending b_dual.real itself would make every
            # history entry alias the same array, which is mutated in place.
            b_hist.append(np.copy(b_dual.real))
            func_val_hist.append(current_dual.real)
            param_error.append(np.linalg.norm(ref_coef - b_hist[-1]) / norm_coeff)
            acc_hist.append(acc)

            # At least two errors are needed before their change can be measured.
            if len(param_error) > 1 and abs(param_error[-1] - param_error[-2]) < tol:
                converged = True
                break

        # Guard: with zero batches there is nothing to report yet.
        if acc_hist and epoch % print_every == 0:
            print("Accuracy: {} | Euclidean Param Norm: {} | fct min: {}".format(
                acc_hist[-1], param_error[-1], func_val_hist[-1]))
        # Stop the outer epoch loop too, not just the batch loop.
        if converged:
            break
    return b_hist, func_val_hist, param_error, acc_hist
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment