Last active
August 18, 2022 16:27
-
-
Save andrewm4894/ce57c09848580c9deb46eb244bb16cea to your computer and use it in GitHub Desktop.
huggingface_text_classification_quickstart.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "huggingface_text_classification_quickstart.ipynb", | |
| "provenance": [], | |
| "collapsed_sections": [], | |
| "toc_visible": true, | |
| "authorship_tag": "ABX9TyO9LEu653YJxFqtsnJARjcs", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU", | |
| "gpuClass": "standard", | |
| "widgets": { | |
| "application/vnd.jupyter.widget-state+json": { | |
| "a8e4f6e9602844d78c9302424d7256a0": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HBoxModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HBoxModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HBoxView", | |
| "box_style": "", | |
| "children": [ | |
| "IPY_MODEL_00d72cc80c8e48e2afa4ceff4e0a6b08", | |
| "IPY_MODEL_9369dc5fa6ae427ba50baf8349fe2555", | |
| "IPY_MODEL_203482a6a68b4fd1ae0ab841f178bf51" | |
| ], | |
| "layout": "IPY_MODEL_7444d67a9a554790a439597dea8731e4" | |
| } | |
| }, | |
| "00d72cc80c8e48e2afa4ceff4e0a6b08": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_3152b2c31f6a4aa1a561371dfd2581e7", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_cfc7e65db5204a4491d397c76caf6da1", | |
| "value": "100%" | |
| } | |
| }, | |
| "9369dc5fa6ae427ba50baf8349fe2555": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "FloatProgressModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "FloatProgressModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "ProgressView", | |
| "bar_style": "success", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_2e65ebacbf084c28bf002b6f3c916493", | |
| "max": 3, | |
| "min": 0, | |
| "orientation": "horizontal", | |
| "style": "IPY_MODEL_bc491f5684054d81a8da052fcd51e93e", | |
| "value": 3 | |
| } | |
| }, | |
| "203482a6a68b4fd1ae0ab841f178bf51": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_ce7193ae0a4e4955a06fff282fa17930", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_cbc7ad0787fc4d96b516e3a72703cb42", | |
| "value": " 3/3 [00:00<00:00, 42.07it/s]" | |
| } | |
| }, | |
| "7444d67a9a554790a439597dea8731e4": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "3152b2c31f6a4aa1a561371dfd2581e7": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "cfc7e65db5204a4491d397c76caf6da1": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| }, | |
| "2e65ebacbf084c28bf002b6f3c916493": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "bc491f5684054d81a8da052fcd51e93e": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "ProgressStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "ProgressStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "bar_color": null, | |
| "description_width": "" | |
| } | |
| }, | |
| "ce7193ae0a4e4955a06fff282fa17930": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "cbc7ad0787fc4d96b516e3a72703cb42": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| }, | |
| "b46476db3f544c5399869f53a56ad7ac": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HBoxModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HBoxModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HBoxView", | |
| "box_style": "", | |
| "children": [ | |
| "IPY_MODEL_9af08056884d4261a35f48a526741c4b", | |
| "IPY_MODEL_ffcf1849d5794d7088a8b773dc9eb1e2", | |
| "IPY_MODEL_2c4896f55c704d01ae15b518516f77f7" | |
| ], | |
| "layout": "IPY_MODEL_d0b71eddd6ee46ee863846eab816dfbf" | |
| } | |
| }, | |
| "9af08056884d4261a35f48a526741c4b": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_b21b3ca8130f4d209df66b661fcf5622", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_a1578b17612e4c4da1faec59082eb9b4", | |
| "value": "100%" | |
| } | |
| }, | |
| "ffcf1849d5794d7088a8b773dc9eb1e2": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "FloatProgressModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "FloatProgressModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "ProgressView", | |
| "bar_style": "success", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_e939bf80bb9f42cd9f0094e83ca7dea0", | |
| "max": 1, | |
| "min": 0, | |
| "orientation": "horizontal", | |
| "style": "IPY_MODEL_d5a05cb08336485f92088400c05a49c7", | |
| "value": 1 | |
| } | |
| }, | |
| "2c4896f55c704d01ae15b518516f77f7": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_a7f1fbd79d764a8480c3dd1ab9ba0f66", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_8073faf1478541398b28d30b7c9b026d", | |
| "value": " 1/1 [00:00<00:00, 1.44ba/s]" | |
| } | |
| }, | |
| "d0b71eddd6ee46ee863846eab816dfbf": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "b21b3ca8130f4d209df66b661fcf5622": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "a1578b17612e4c4da1faec59082eb9b4": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| }, | |
| "e939bf80bb9f42cd9f0094e83ca7dea0": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "d5a05cb08336485f92088400c05a49c7": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "ProgressStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "ProgressStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "bar_color": null, | |
| "description_width": "" | |
| } | |
| }, | |
| "a7f1fbd79d764a8480c3dd1ab9ba0f66": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "8073faf1478541398b28d30b7c9b026d": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| }, | |
| "3c3ee1b9872246c08f3e653ecd889fc6": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HBoxModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HBoxModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HBoxView", | |
| "box_style": "", | |
| "children": [ | |
| "IPY_MODEL_f9670e54b6494dd29b3995b829f04e92", | |
| "IPY_MODEL_638b4e2a800348c89074e74bfb5a40eb", | |
| "IPY_MODEL_d546b7955af3437592a383a89acef679" | |
| ], | |
| "layout": "IPY_MODEL_e70bb17b54814092a53831578316ec7c" | |
| } | |
| }, | |
| "f9670e54b6494dd29b3995b829f04e92": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_041e1c9dba454ea1bf164572d80c8ba6", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_99024854e78a4cb29c96e135c773542f", | |
| "value": "100%" | |
| } | |
| }, | |
| "638b4e2a800348c89074e74bfb5a40eb": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "FloatProgressModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "FloatProgressModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "ProgressView", | |
| "bar_style": "success", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_cbed4cabd860432080d23c0ac5b2e336", | |
| "max": 1, | |
| "min": 0, | |
| "orientation": "horizontal", | |
| "style": "IPY_MODEL_cd1aafa89c954e60949191d2bf8763f9", | |
| "value": 1 | |
| } | |
| }, | |
| "d546b7955af3437592a383a89acef679": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_9d2c369dbc73497e9672972e97702f08", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_4bae1123b9dd4ccfab8d4b6baa0c8137", | |
| "value": " 1/1 [00:00<00:00, 1.46ba/s]" | |
| } | |
| }, | |
| "e70bb17b54814092a53831578316ec7c": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "041e1c9dba454ea1bf164572d80c8ba6": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "99024854e78a4cb29c96e135c773542f": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| }, | |
| "cbed4cabd860432080d23c0ac5b2e336": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "cd1aafa89c954e60949191d2bf8763f9": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "ProgressStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "ProgressStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "bar_color": null, | |
| "description_width": "" | |
| } | |
| }, | |
| "9d2c369dbc73497e9672972e97702f08": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "4bae1123b9dd4ccfab8d4b6baa0c8137": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| } | |
| } | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/andrewm4894/ce57c09848580c9deb46eb244bb16cea/huggingface_text_classification_quickstart.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# install what we need if not already installed\n", | |
| "#!pip install datasets transformers" | |
| ], | |
| "metadata": { | |
| "id": "0USd3nerpZ96" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "8abgt_d7owDr" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# minimal example here should work for either 'pytorch' or 'tensorflow'\n", | |
| "framework = 'tensorflow'" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Imports" | |
| ], | |
| "metadata": { | |
| "id": "Gx7igJNVo-g5" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import pandas as pd\n", | |
| "\n", | |
| "from datasets import load_dataset, Dataset, DatasetDict, Value, ClassLabel, Features\n", | |
| "from transformers import DataCollatorWithPadding, AutoTokenizer, pipeline\n", | |
| "\n", | |
| "if framework == 'tensorflow':\n", | |
| " from transformers import TFAutoModelForSequenceClassification, create_optimizer\n", | |
| " import tensorflow as tf\n", | |
| "else:\n", | |
| " from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", | |
| "\n", | |
| "\n", | |
| "def preprocess_function(examples):\n", | |
| " return tokenizer(examples[\"text\"], truncation=True)\n" | |
| ], | |
| "metadata": { | |
| "id": "immGFdDxo6UC" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Params" | |
| ], | |
| "metadata": { | |
| "id": "PqMIFVy3qlag" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# downsample data for speed in this example\n", | |
| "n_train = 1000\n", | |
| "n_test = 1000\n", | |
| "\n", | |
| "# ml inputs\n", | |
| "batch_size = 16\n", | |
| "learning_rate = 2e-5\n", | |
| "num_epochs = 2\n", | |
| "weight_decay = 0.01" | |
| ], | |
| "metadata": { | |
| "id": "656s4baDqm_i" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Get Data" | |
| ], | |
| "metadata": { | |
| "id": "CJ85wluppAg6" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# load imdb data\n", | |
| "data = load_dataset(\"imdb\")\n", | |
| "\n", | |
| "# pull data into pandas dataframes to downsample\n", | |
| "df_train = pd.DataFrame.from_dict(data['train']).sample(n_train)\n", | |
| "df_test = pd.DataFrame.from_dict(data['test']).sample(n_test)\n", | |
| "\n", | |
| "# now build back up a DatasetDict based on the downsampled data\n", | |
| "\n", | |
| "# define the features\n", | |
| "features = Features({\n", | |
| " \"text\": Value(\"string\"), \n", | |
| " \"label\": ClassLabel(num_classes=2, names=['neg','pos']),\n", | |
| " \"__index_level_0__\": Value(\"string\") \n", | |
| " })\n", | |
| "\n", | |
| "# recreate the data object using the smaller df's\n", | |
| "data = DatasetDict({\n", | |
| " 'train': Dataset.from_pandas(df_train, features=features),\n", | |
| " 'test': Dataset.from_pandas(df_test, features=features),\n", | |
| " })\n", | |
| "\n", | |
| "# remove index col (seems to be coming in from pandas for some reason)\n", | |
| "data = data.remove_columns([\"__index_level_0__\"])\n", | |
| "\n", | |
| "# look at data\n", | |
| "print(data['train'].features)\n", | |
| "print(data['test'].features)\n", | |
| "print(data)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 295, | |
| "referenced_widgets": [ | |
| "a8e4f6e9602844d78c9302424d7256a0", | |
| "00d72cc80c8e48e2afa4ceff4e0a6b08", | |
| "9369dc5fa6ae427ba50baf8349fe2555", | |
| "203482a6a68b4fd1ae0ab841f178bf51", | |
| "7444d67a9a554790a439597dea8731e4", | |
| "3152b2c31f6a4aa1a561371dfd2581e7", | |
| "cfc7e65db5204a4491d397c76caf6da1", | |
| "2e65ebacbf084c28bf002b6f3c916493", | |
| "bc491f5684054d81a8da052fcd51e93e", | |
| "ce7193ae0a4e4955a06fff282fa17930", | |
| "cbc7ad0787fc4d96b516e3a72703cb42" | |
| ] | |
| }, | |
| "id": "AJ8d7B01o7fL", | |
| "outputId": "aa44cf3b-229b-44fe-a6f5-561ae494a4c4" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "WARNING:datasets.builder:Reusing dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| " 0%| | 0/3 [00:00<?, ?it/s]" | |
| ], | |
| "application/vnd.jupyter.widget-view+json": { | |
| "version_major": 2, | |
| "version_minor": 0, | |
| "model_id": "a8e4f6e9602844d78c9302424d7256a0" | |
| } | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "{'text': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['neg', 'pos'], id=None)}\n", | |
| "{'text': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['neg', 'pos'], id=None)}\n", | |
| "DatasetDict({\n", | |
| " train: Dataset({\n", | |
| " features: ['text', 'label'],\n", | |
| " num_rows: 1000\n", | |
| " })\n", | |
| " test: Dataset({\n", | |
| " features: ['text', 'label'],\n", | |
| " num_rows: 1000\n", | |
| " })\n", | |
| "})\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Tokenize" | |
| ], | |
| "metadata": { | |
| "id": "n2GaP1HHpKmi" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# create tokenizer\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n", | |
| "\n", | |
| "# tokenize the data\n", | |
| "tokenized_data = data.map(preprocess_function, batched=True)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 81, | |
| "referenced_widgets": [ | |
| "b46476db3f544c5399869f53a56ad7ac", | |
| "9af08056884d4261a35f48a526741c4b", | |
| "ffcf1849d5794d7088a8b773dc9eb1e2", | |
| "2c4896f55c704d01ae15b518516f77f7", | |
| "d0b71eddd6ee46ee863846eab816dfbf", | |
| "b21b3ca8130f4d209df66b661fcf5622", | |
| "a1578b17612e4c4da1faec59082eb9b4", | |
| "e939bf80bb9f42cd9f0094e83ca7dea0", | |
| "d5a05cb08336485f92088400c05a49c7", | |
| "a7f1fbd79d764a8480c3dd1ab9ba0f66", | |
| "8073faf1478541398b28d30b7c9b026d", | |
| "3c3ee1b9872246c08f3e653ecd889fc6", | |
| "f9670e54b6494dd29b3995b829f04e92", | |
| "638b4e2a800348c89074e74bfb5a40eb", | |
| "d546b7955af3437592a383a89acef679", | |
| "e70bb17b54814092a53831578316ec7c", | |
| "041e1c9dba454ea1bf164572d80c8ba6", | |
| "99024854e78a4cb29c96e135c773542f", | |
| "cbed4cabd860432080d23c0ac5b2e336", | |
| "cd1aafa89c954e60949191d2bf8763f9", | |
| "9d2c369dbc73497e9672972e97702f08", | |
| "4bae1123b9dd4ccfab8d4b6baa0c8137" | |
| ] | |
| }, | |
| "id": "Pv--vigxpOYa", | |
| "outputId": "49539298-69cb-46f5-ef00-ebfcbfecb1d9" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| " 0%| | 0/1 [00:00<?, ?ba/s]" | |
| ], | |
| "application/vnd.jupyter.widget-view+json": { | |
| "version_major": 2, | |
| "version_minor": 0, | |
| "model_id": "b46476db3f544c5399869f53a56ad7ac" | |
| } | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| " 0%| | 0/1 [00:00<?, ?ba/s]" | |
| ], | |
| "application/vnd.jupyter.widget-view+json": { | |
| "version_major": 2, | |
| "version_minor": 0, | |
| "model_id": "3c3ee1b9872246c08f3e653ecd889fc6" | |
| } | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Train" | |
| ], | |
| "metadata": { | |
| "id": "Rinpt0BOpH9K" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# train mode based on framework\n", | |
| "if framework == 'pytorch':\n", | |
| " \n", | |
| " data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n", | |
| " \n", | |
| " model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)\n", | |
| " \n", | |
| " training_args = TrainingArguments(\n", | |
| " output_dir=\"./results\",\n", | |
| " learning_rate=learning_rate,\n", | |
| " per_device_train_batch_size=batch_size,\n", | |
| " per_device_eval_batch_size=batch_size,\n", | |
| " num_train_epochs=num_epochs,\n", | |
| " weight_decay=weight_decay,\n", | |
| " )\n", | |
| " \n", | |
| " trainer = Trainer(\n", | |
| " model=model,\n", | |
| " args=training_args,\n", | |
| " train_dataset=tokenized_data[\"train\"],\n", | |
| " eval_dataset=tokenized_data[\"test\"],\n", | |
| " tokenizer=tokenizer,\n", | |
| " data_collator=data_collator,\n", | |
| " ) \n", | |
| " \n", | |
| " trainer.train()\n", | |
| " \n", | |
| "elif framework == 'tensorflow':\n", | |
| " \n", | |
| " data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors=\"tf\")\n", | |
| " \n", | |
| " tf_train_set = tokenized_data[\"train\"].to_tf_dataset(\n", | |
| " columns=[\"attention_mask\", \"input_ids\", \"label\"],\n", | |
| " shuffle=True,\n", | |
| " batch_size=batch_size,\n", | |
| " collate_fn=data_collator,\n", | |
| " )\n", | |
| "\n", | |
| " tf_validation_set = tokenized_data[\"test\"].to_tf_dataset(\n", | |
| " columns=[\"attention_mask\", \"input_ids\", \"label\"],\n", | |
| " shuffle=False,\n", | |
| " batch_size=batch_size,\n", | |
| " collate_fn=data_collator,\n", | |
| " )\n", | |
| " \n", | |
| " batches_per_epoch = len(tokenized_data[\"train\"]) // batch_size\n", | |
| " total_train_steps = int(batches_per_epoch * num_epochs)\n", | |
| " optimizer, schedule = create_optimizer(init_lr=learning_rate, num_warmup_steps=0, num_train_steps=total_train_steps)\n", | |
| " \n", | |
| " model = TFAutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)\n", | |
| " \n", | |
| " model.compile(optimizer=optimizer)\n", | |
| " \n", | |
| " model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_epochs)\n", | |
| "\n", | |
| "else: \n", | |
| " \n", | |
| " raise ValueError('unsupported framework')" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "m6X5UBpGpHMa", | |
| "outputId": "693b69fd-8e68-4e22-94ab-6bcab53a1fa5" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertForSequenceClassification: ['vocab_transform', 'vocab_projector', 'activation_13', 'vocab_layer_norm']\n", | |
| "- This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
| "- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | |
| "Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['dropout_19', 'classifier', 'pre_classifier']\n", | |
| "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | |
| "No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Epoch 1/2\n", | |
| "63/63 [==============================] - 56s 726ms/step - loss: 0.5812 - val_loss: 0.3526\n", | |
| "Epoch 2/2\n", | |
| "63/63 [==============================] - 44s 702ms/step - loss: 0.2713 - val_loss: 0.2748\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Inference Pipeline" | |
| ], | |
| "metadata": { | |
| "id": "a_6YFAMPpQqa" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# create pipeline for inference\n", | |
| "classifier = pipeline(\n", | |
| " task=\"text-classification\", \n", | |
| " model=model, \n", | |
| " tokenizer=tokenizer, \n", | |
| " device=0\n", | |
| " )" | |
| ], | |
| "metadata": { | |
| "id": "4AkV11XJpTbB" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "classifier(\"this is a great movie\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "5vwRTe1OpUsS", | |
| "outputId": "524ec120-9877-41f1-9ca3-5b0ff3d3a789" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "[{'label': 'LABEL_1', 'score': 0.9083077311515808}]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 9 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "classifier(\"this is a terrible movie\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "lb2OEWvXpVP7", | |
| "outputId": "a343b03a-649e-456b-d7ab-60c2a2f96330" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "[{'label': 'LABEL_0', 'score': 0.7608168125152588}]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 10 | |
| } | |
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment