mirror of https://github.com/microsoft/autogen.git
transformer example
This commit is contained in:
parent
71acb5140b
commit
d659079a5d
|
@ -0,0 +1,788 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook uses the Huggingface transformers library to finetune a transformer model.\n",
|
||||
"\n",
|
||||
"**Requirements.** This notebook has additional requirements:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"pip install -r transformers_requirements.txt\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"MODEL_CHECKPOINT = \"distilbert-base-uncased\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'input_ids': [101, 2023, 2003, 1037, 3231, 102], 'attention_mask': [1, 1, 1, 1, 1, 1]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer(\"this is a test\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TASK = \"cola\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Reusing dataset glue (/home/amin/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)\n",
|
||||
"/home/amin/miniconda/lib/python3.7/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:100.)\n",
|
||||
" return torch._C._cuda_getDeviceCount() > 0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"raw_dataset = datasets.load_dataset(\"glue\", TASK)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# define tokenization function used to process data\n",
|
||||
"COLUMN_NAME = \"sentence\"\n",
|
||||
"def tokenize(examples):\n",
|
||||
" return tokenizer(examples[COLUMN_NAME], truncation=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5bd7b23a478043eaaf6e14e119143fcd",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d7b648c2dbdc4fb9907e43da7db8af9a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "36a9d6e62dbe462d94b1769f36fbd0f3",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"encoded_dataset = raw_dataset.map(tokenize, batched=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
||||
" 'idx': 0,\n",
|
||||
" 'input_ids': [101,\n",
|
||||
" 2256,\n",
|
||||
" 2814,\n",
|
||||
" 2180,\n",
|
||||
" 1005,\n",
|
||||
" 1056,\n",
|
||||
" 4965,\n",
|
||||
" 2023,\n",
|
||||
" 4106,\n",
|
||||
" 1010,\n",
|
||||
" 2292,\n",
|
||||
" 2894,\n",
|
||||
" 1996,\n",
|
||||
" 2279,\n",
|
||||
" 2028,\n",
|
||||
" 2057,\n",
|
||||
" 16599,\n",
|
||||
" 1012,\n",
|
||||
" 102],\n",
|
||||
" 'label': 1,\n",
|
||||
" 'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\"}"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"encoded_dataset[\"train\"][0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForSequenceClassification"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "35b76e51b5c8406fae416fcdc3dd885e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']\n",
|
||||
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
|
||||
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
||||
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']\n",
|
||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"NUM_LABELS = 2\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=NUM_LABELS)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"DistilBertForSequenceClassification(\n",
|
||||
" (distilbert): DistilBertModel(\n",
|
||||
" (embeddings): Embeddings(\n",
|
||||
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
|
||||
" (position_embeddings): Embedding(512, 768)\n",
|
||||
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" )\n",
|
||||
" (transformer): Transformer(\n",
|
||||
" (layer): ModuleList(\n",
|
||||
" (0): TransformerBlock(\n",
|
||||
" (attention): MultiHeadSelfAttention(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" (ffn): FFN(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" )\n",
|
||||
" (1): TransformerBlock(\n",
|
||||
" (attention): MultiHeadSelfAttention(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" (ffn): FFN(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" )\n",
|
||||
" (2): TransformerBlock(\n",
|
||||
" (attention): MultiHeadSelfAttention(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" (ffn): FFN(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" )\n",
|
||||
" (3): TransformerBlock(\n",
|
||||
" (attention): MultiHeadSelfAttention(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" (ffn): FFN(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" )\n",
|
||||
" (4): TransformerBlock(\n",
|
||||
" (attention): MultiHeadSelfAttention(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" (ffn): FFN(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" )\n",
|
||||
" (5): TransformerBlock(\n",
|
||||
" (attention): MultiHeadSelfAttention(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" (ffn): FFN(\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||||
" )\n",
|
||||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
|
||||
" (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.2, inplace=False)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Metric"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric = datasets.load_metric(\"glue\", TASK)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Metric(name: \"glue\", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: \"\"\"\n",
|
||||
"Compute GLUE evaluation metric associated to each GLUE dataset.\n",
|
||||
"Args:\n",
|
||||
" predictions: list of translations to score.\n",
|
||||
" Each translation should be tokenized into a list of tokens.\n",
|
||||
" references: list of lists of references for each translation.\n",
|
||||
" Each reference should be tokenized into a list of tokens.\n",
|
||||
"Returns: depending on the GLUE subset, one or several of:\n",
|
||||
" \"accuracy\": Accuracy\n",
|
||||
" \"f1\": F1\n",
|
||||
" \"pearson\": Pearson Correlation\n",
|
||||
" \"spearmanr\": Spearman Correlation\n",
|
||||
" \"matthews_correlation\": Matthew Correlation\n",
|
||||
"\"\"\", stored examples: 0)"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"metric"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compute_metrics(eval_pred):\n",
|
||||
" predictions, labels = eval_pred\n",
|
||||
" predictions = np.argmax(predictions, axis=1)\n",
|
||||
" return metric.compute(predictions=predictions, references=labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Training (aka Finetuning)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Trainer\n",
|
||||
"from transformers import TrainingArguments"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args = TrainingArguments(\n",
|
||||
" output_dir='output',\n",
|
||||
" do_eval=True,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer = Trainer(\n",
|
||||
" model=model,\n",
|
||||
" args=args,\n",
|
||||
" train_dataset=encoded_dataset[\"train\"],\n",
|
||||
" eval_dataset=encoded_dataset[\"validation\"],\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div>\n",
|
||||
" <style>\n",
|
||||
" /* Turns off some styling */\n",
|
||||
" progress {\n",
|
||||
" /* gets rid of default border in Firefox and Opera. */\n",
|
||||
" border: none;\n",
|
||||
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
||||
" background-size: auto;\n",
|
||||
" }\n",
|
||||
" </style>\n",
|
||||
" \n",
|
||||
" <progress value='322' max='3207' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [ 322/3207 02:51 < 25:41, 1.87 it/s, Epoch 0.30/3]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: left;\">\n",
|
||||
" <th>Step</th>\n",
|
||||
" <th>Training Loss</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"## Hyperparameter Optimization\n",
|
||||
"\n",
|
||||
"`flaml.tune` is a module for economical hyperparameter tuning. It frees users from manually tuning the many hyperparameters of a piece of software, such as a machine learning training procedure. \n",
|
||||
"The API is compatible with ray tune.\n",
|
||||
"\n",
|
||||
"### Step 1. Define training method\n",
|
||||
"\n",
|
||||
"We define a function `train_distilbert(config: dict)` that accepts a hyperparameter configuration dict `config`. The specific configs will be generated by flaml's search algorithm in a given search space.\n"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import flaml\n",
|
||||
"\n",
|
||||
"def train_distilbert(config: dict):\n",
|
||||
"\n",
|
||||
" # Define tokenize method\n",
|
||||
" tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)\n",
|
||||
" def tokenize(examples):\n",
|
||||
" return tokenizer(examples[COLUMN_NAME], truncation=True)\n",
|
||||
" # Load CoLA dataset and apply tokenizer\n",
|
||||
" cola_raw = load_dataset(\"glue\", TASK)\n",
|
||||
" cola_encoded = cola_raw.map(tokenize, batched=True)\n",
|
||||
" # QUESTION: Write processed data to disk?\n",
|
||||
" train_dataset, eval_dataset = cola_encoded[\"train\"], cola_encoded[\"validation\"]\n",
|
||||
"\n",
|
||||
" model = AutoModelForSequenceClassification.from_pretrained(\n",
|
||||
" MODEL_CHECKPOINT, num_labels=NUM_LABELS\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" metric = load_metric(\"glue\", TASK)\n",
|
||||
"\n",
|
||||
" training_args = TrainingArguments(\n",
|
||||
" output_dir='.',\n",
|
||||
" do_eval=False,\n",
|
||||
" **config,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" trainer = Trainer(\n",
|
||||
" model,\n",
|
||||
" training_args,\n",
|
||||
" train_dataset=train_dataset,\n",
|
||||
" eval_dataset=eval_dataset,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # train model\n",
|
||||
" trainer.train()\n",
|
||||
"\n",
|
||||
" # evaluate model\n",
|
||||
" eval_output = trainer.evaluate()\n",
|
||||
"\n",
|
||||
" # report the metric to optimize\n",
|
||||
" flaml.tune.report(\n",
|
||||
" loss=eval_output[\"eval_loss\"],\n",
|
||||
" matthews_correlation=eval_output[\"eval_matthews_correlation\"],\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"### Step 2. Define the search\n",
|
||||
"\n",
|
||||
"We are now ready to define our search. This includes:\n",
|
||||
"\n",
|
||||
"- The `search_space` for our hyperparameters\n",
|
||||
"- The metric and the mode ('max' or 'min') for optimization\n",
|
||||
"- The resources (`num_cpus`, `num_gpus`) and the constraints (`num_samples` and `time_budget_s`)"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_num_epoch = 4\n",
|
||||
"search_space = {\n",
|
||||
" # You can mix constants with search space objects.\n",
|
||||
" \"num_train_epochs\": flaml.tune.loguniform(1, max_num_epoch),\n",
|
||||
" \"learning_rate\": flaml.tune.loguniform(1e-6, 1e-4),\n",
|
||||
" \"adam_epsilon\": flaml.tune.loguniform(1e-9, 1e-7),\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# optimization objective\n",
|
||||
"HP_METRIC, MODE = \"matthews_correlation\", \"max\"\n",
|
||||
"\n",
|
||||
"# resources\n",
|
||||
"num_cpus = 2\n",
|
||||
"num_gpus = 2\n",
|
||||
"\n",
|
||||
"# constraints\n",
|
||||
"num_samples = -1 # number of trials, -1 means unlimited\n",
|
||||
"time_budget_s = 3600 # time budget in seconds"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"### Step 3. Launch with `flaml.tune.run`\n",
|
||||
"\n",
|
||||
"We are now ready to launch the tuning using `flaml.tune.run`:"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"import ray\n",
|
||||
"start_time = time.time()\n",
|
||||
"ray.init(num_cpus=num_cpus, num_gpus=num_gpus)\n",
|
||||
"\n",
|
||||
"print(\"Tuning started...\")\n",
|
||||
"analysis = flaml.tune.run(\n",
|
||||
" train_distilbert,\n",
|
||||
" config=search_space,\n",
|
||||
" init_config={\n",
|
||||
" \"num_train_epochs\": 1,\n",
|
||||
" },\n",
|
||||
" metric=HP_METRIC,\n",
|
||||
" mode=MODE,\n",
|
||||
" report_intermediate_result=False,\n",
|
||||
" # uncomment the following if report_intermediate_result = True\n",
|
||||
" # max_resource=max_num_epoch, min_resource=1,\n",
|
||||
" resources_per_trial={\"gpu\": 1},\n",
|
||||
" local_dir='logs/',\n",
|
||||
" num_samples=num_samples,\n",
|
||||
" time_budget_s=time_budget_s,\n",
|
||||
" use_ray=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"ray.shutdown()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"best_trial = analysis.get_best_trial(HP_METRIC, MODE, \"all\")\n",
|
||||
"metric = best_trial.metric_analysis[HP_METRIC][MODE]\n",
|
||||
"print(f\"n_trials={len(analysis.trials)}\")\n",
|
||||
"print(f\"time={time.time()-start_time}\")\n",
|
||||
"print(f\"Best model eval {HP_METRIC}: {metric:.4f}\")\n",
|
||||
"print(f\"Best model parameters: {best_trial.config}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"## Next Steps\n",
|
||||
"\n",
|
||||
"Notice that we only reported the metric with `flaml.tune.report` at the end of the full training loop. It is possible to enable reporting of intermediate performance - allowing early stopping - as follows:\n",
|
||||
"\n",
|
||||
"- Huggingface provides _Callbacks_ which can be used to insert the `flaml.tune.report` call inside the training loop\n",
|
||||
"- Make sure to set `do_eval=True` in the `TrainingArguments` provided to `Trainer` and adjust the evaluation frequency accordingly"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "flaml",
|
||||
"language": "python",
|
||||
"name": "flaml"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
torch
|
||||
transformers
|
||||
datasets
|
||||
ipywidgets
|
Loading…
Reference in New Issue