@swati210994
Created October 15, 2020 16:55
{
"cells": [
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"F1 score 0.9811320754716981\n",
"Classification Report\n",
"              precision    recall  f1-score   support\n",
"\n",
"         ham       0.99      1.00      1.00       954\n",
"        spam       0.99      0.97      0.98       161\n",
"\n",
"    accuracy                           0.99      1115\n",
"   macro avg       0.99      0.98      0.99      1115\n",
"weighted avg       0.99      0.99      0.99      1115\n",
"\n",
"Training and saving built model.....\n"
]
}
],
"source": [
"model_save_path='./bert_model.h5'\n",
"\n",
"trained_model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2)\n",
"trained_model.compile(loss=loss,optimizer=optimizer, metrics=[metric])\n",
"trained_model.load_weights(model_save_path)\n",
"\n",
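"preds = trained_model.predict([val_inp,val_mask],batch_size=32)\n",
"# transformers 4.x returns a TFSequenceClassifierOutput; older releases return the logits array itself\n",
"logits = preds.logits if hasattr(preds, 'logits') else preds\n",
"pred_labels = logits.argmax(axis=1)\n",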
"f1 = f1_score(val_label,pred_labels)\n",
"print('F1 score',f1)\n",
"print('Classification Report')\n",
"print(classification_report(val_label,pred_labels,target_names=target_names))\n",
"\n",
"print('Training and saving built model.....') "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@NonlinearNimesh

'TFSequenceClassifierOutput' object has no attribute 'argmax'

Can you please help me get rid of this error?
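A minimal sketch of one likely fix, assuming the error comes from running this notebook with a transformers 4.x release. There, `predict` returns a `TFSequenceClassifierOutput` object rather than a bare NumPy array of logits, so `argmax` has to be called on its `logits` attribute. The helper name `predictions_to_labels` is hypothetical; it also accepts a plain array so the same code keeps working with older releases.

```python
import numpy as np


def predictions_to_labels(preds):
    """Hypothetical helper: turn the result of model.predict() into class labels.

    Assumes newer transformers releases return a ModelOutput (for example
    TFSequenceClassifierOutput) exposing the class scores as `.logits`,
    while older releases return the logits array directly.
    """
    # Use the `logits` attribute if present; otherwise treat preds as the array itself
    logits = getattr(preds, "logits", preds)
    # argmax over the class axis yields one predicted label per example
    return np.asarray(logits).argmax(axis=1)
```

In the notebook cell, `pred_labels = preds.argmax(axis=1)` would then become `pred_labels = predictions_to_labels(preds)`; the `f1_score` and `classification_report` calls stay unchanged.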
