Skip to content

Instantly share code, notes, and snippets.

@darenr
Created March 2, 2022 19:37
Show Gist options
  • Select an option

  • Save darenr/43256ab1eaba281a9b5063ef298d013e to your computer and use it in GitHub Desktop.

Select an option

Save darenr/43256ab1eaba281a9b5063ef298d013e to your computer and use it in GitHub Desktop.
XGBoost with Tensorboard
import datetime
import os

import xgboost as xgb
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter
class TensorBoardCallback(xgb.callback.TrainingCallback):
    """Write XGBoost evaluation metrics to TensorBoard event files.

    Two ``SummaryWriter`` instances are created under
    ``runs/<experiment>/<timestamp>/``: one in ``train/`` and one in
    ``<data_name>/``. After every boosting round the most recent value of
    each metric in ``evals_log`` is written as a scalar: metrics of the
    dataset named ``"train"`` go to the train writer, metrics of every
    other dataset go to the second writer.
    """

    def __init__(self, experiment: str = None, data_name: str = None):
        """Create the log directory layout and open both writers.

        Args:
            experiment: Sub-directory of ``runs/`` for this experiment.
                Falls back to ``"logs"`` when empty or ``None``.
            data_name: Directory name for the non-training writer.
                Falls back to ``"test"`` when empty or ``None``.
        """
        super().__init__()
        self.experiment = experiment or "logs"
        self.data_name = data_name or "test"
        # The timestamp keeps successive runs of one experiment apart.
        self.datetime_ = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.log_dir = f"runs/{self.experiment}/{self.datetime_}"
        self.train_writer = SummaryWriter(
            log_dir=os.path.join(self.log_dir, "train/")
        )
        # data_name is never empty after the fallback above, so the second
        # writer is always created; after_iteration relies on it existing.
        self.test_writer = SummaryWriter(
            log_dir=os.path.join(self.log_dir, f"{self.data_name}/")
        )

    def after_iteration(
        self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
    ) -> bool:
        """Log the latest value of every metric for the finished round.

        Args:
            model: The model being trained (unused).
            epoch: Index of the boosting round, used as the scalar step.
            evals_log: Mapping of dataset name to a mapping of metric name
                to the history of that metric.

        Returns:
            Always ``False``; this callback never requests early stopping.
        """
        if not evals_log:
            return False

        for data, metric in evals_log.items():
            # Only the dataset literally named "train" uses the train
            # writer; every other dataset shares the second writer.
            writer = self.train_writer if data == "train" else self.test_writer
            for metric_name, log in metric.items():
                latest = log[-1]
                # A history entry may be a tuple; its first item is logged.
                score = latest[0] if isinstance(latest, tuple) else latest
                writer.add_scalar(metric_name, score, epoch)

        return False

    def after_training(self, model):
        """Flush pending events to disk once training has finished.

        Returns:
            The model, unchanged.
        """
        self.train_writer.flush()
        self.test_writer.flush()
        return model
# Demo: train a regressor and stream its train/test RMSE to TensorBoard.
# NOTE(review): load_boston was removed in scikit-learn 1.2, so this demo
# needs an older scikit-learn or a different dataset loader - confirm the
# target environment.
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=100
)

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

params = {"objective": "reg:squarederror", "eval_metric": "rmse"}
# The dataset names here decide which writer the callback uses for each
# metric: "train" goes to the train writer, anything else to the other.
watchlist = [(dtrain, "train"), (dtest, "test")]
tensorboard_cb = TensorBoardCallback(experiment="exp_1", data_name="test")

bst = xgb.train(
    params,
    dtrain,
    num_boost_round=100,
    evals=watchlist,
    callbacks=[tensorboard_cb],
)
@darenr
Copy link
Author

darenr commented Mar 2, 2022

XGBoost w/ Tensorboard

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment