mirror of https://github.com/microsoft/autogen.git

add notebook (#109)

* added support for transformers==3.4.0
* updating error message
* adding arxiv

This commit is contained in:
parent 183b867856
commit cd4be9c0e5
@@ -1,4 +1,8 @@
-How to use AutoTransformers:
+# Hyperparameter Optimization for Huggingface Transformers

AutoTransformers is an AutoML class for fine-tuning pre-trained language models based on the transformers library.

+An example of using AutoTransformers:
+
+```python
+from flaml.nlp.autotransformers import AutoTransformers

@@ -29,4 +33,12 @@ The current use cases that are supported:

The use cases that can be supported in the future:
1. HPO fine-tuning for text generation;
2. HPO fine-tuning for question answering;
+
+### Troubleshooting fine-tuning HPO for pre-trained language models
+
+To reproduce the results for our ACL2021 paper:
+
+*[An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. To appear in ACL-IJCNLP 2021*
+
+Please refer to the following jupyter notebook: [Troubleshooting HPO for fine-tuning pre-trained language models](https://github.com/microsoft/FLAML/blob/main/notebook/research/acl2021.ipynb)
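Note: the README's example is truncated in this diff. For orientation, here is a minimal end-to-end sketch assembled from the notebook added later in this commit (`prepare_data`, `fit`, and `predict` as shown there; the settings are illustrative values from the notebook, not defaults):

```python
from flaml.nlp.autotransformers import AutoTransformers

autohf = AutoTransformers()
autohf.prepare_data(dataset_subdataset_name="glue:mrpc",
                    pretrained_model_size="google/electra-base-discriminator:base",
                    data_root_path="data/",
                    max_seq_length=128)
validation_metric, analysis = autohf.fit(num_samples=1,
                                         time_budget=100000,
                                         ckpt_per_epoch=5,
                                         fp16=True,
                                         resources_per_trial={"gpu": 1, "cpu": 1})
predictions, test_metric = autohf.predict()
```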
@@ -6,6 +6,7 @@ import logging

try:
    import ray
+    import transformers
    from transformers import TrainingArguments
    import datasets
    import torch

@@ -16,11 +17,6 @@ from .dataset.task_auto import get_default_task
from .result_analysis.azure_utils import JobID
from .huggingface.trainer import TrainerForAutoTransformers

-logger = logging.getLogger(__name__)
-logger_formatter = logging.Formatter(
-    '[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s',
-    '%m-%d %H:%M:%S')
-
task_list = [
    "seq-classification",
    "regression",
@@ -173,7 +169,7 @@ class AutoTransformers:
        if is_wandb_on:
            from .result_analysis.wandb_utils import WandbUtils
            self.wandb_utils = WandbUtils(is_wandb_on=is_wandb_on,
-                                          console_args=console_args,
+                                          wandb_key_path=console_args.key_path,
                                          jobid_config=self.jobid_config)
            self.wandb_utils.set_wandb_per_run()
        else:
@@ -358,9 +354,8 @@ class AutoTransformers:
        return training_args_config, per_model_config

    def _objective(self, config, reporter, checkpoint_dir=None):
-        from transformers import IntervalStrategy

        from transformers.trainer_utils import set_seed
+        self._set_transformers_verbosity(self._transformers_verbose)

        def model_init():
            return self._load_model()
@@ -377,17 +372,32 @@ class AutoTransformers:
            batch_size=config["per_device_train_batch_size"])

        assert self.path_utils.ckpt_dir_per_trial
-        training_args = TrainingArguments(
-            output_dir=self.path_utils.ckpt_dir_per_trial,
-            do_eval=False,
-            per_device_eval_batch_size=32,
-            eval_steps=ckpt_freq,
-            evaluation_strategy=IntervalStrategy.STEPS,
-            save_steps=ckpt_freq,
-            save_total_limit=0,
-            fp16=self._fp16,
-            **training_args_config,
-        )
+        if transformers.__version__.startswith("3"):
+            training_args = TrainingArguments(
+                output_dir=self.path_utils.ckpt_dir_per_trial,
+                do_eval=True,
+                per_device_eval_batch_size=32,
+                eval_steps=ckpt_freq,
+                evaluate_during_training=True,
+                save_steps=ckpt_freq,
+                save_total_limit=0,
+                fp16=self._fp16,
+                **training_args_config,
+            )
+        else:
+            from transformers import IntervalStrategy
+            training_args = TrainingArguments(
+                output_dir=self.path_utils.ckpt_dir_per_trial,
+                do_eval=True,
+                per_device_eval_batch_size=32,
+                eval_steps=ckpt_freq,
+                evaluation_strategy=IntervalStrategy.STEPS,
+                save_steps=ckpt_freq,
+                save_total_limit=0,
+                fp16=self._fp16,
+                **training_args_config,
+            )

        trainer = TrainerForAutoTransformers(
            this_model,

@@ -398,7 +408,6 @@ class AutoTransformers:
            tokenizer=self._tokenizer,
            compute_metrics=self._compute_metrics_by_dataset_name,
        )
-        trainer.logger = logger
        trainer.trial_id = reporter.trial_id

        """
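Note: the version gate above is the crux of the transformers==3.4.0 support this commit adds. A minimal sketch of the same pattern, under the assumption of a hypothetical helper name (`make_training_args` is not part of FLAML; the keyword split between 3.x and 4.x follows the branch above):

```python
import transformers
from transformers import TrainingArguments

def make_training_args(output_dir, ckpt_freq, **overrides):
    # Keyword arguments accepted by both transformers 3.x and 4.x.
    common = dict(output_dir=output_dir, do_eval=True, eval_steps=ckpt_freq,
                  save_steps=ckpt_freq, save_total_limit=0, **overrides)
    if transformers.__version__.startswith("3"):
        # 3.x spells periodic evaluation as a boolean flag.
        return TrainingArguments(evaluate_during_training=True, **common)
    # 4.x replaced the flag with an enum-valued strategy.
    from transformers import IntervalStrategy
    return TrainingArguments(evaluation_strategy=IntervalStrategy.STEPS, **common)
```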
@@ -498,7 +507,7 @@ class AutoTransformers:
            ckpt_json = json.load(open(ckpt_dir))
            return ckpt_json["best_ckpt"]
        except FileNotFoundError as err:
-            logger.error("Saved checkpoint not found. Please make sure checkpoint is stored under {}".format(ckpt_dir))
+            print("Saved checkpoint not found. Please make sure checkpoint is stored under {}".format(ckpt_dir))
            raise err

    def _set_metric(self, custom_metric_name=None, custom_metric_mode_name=None):

@@ -511,14 +520,12 @@ class AutoTransformers:
            subdataset_name=self.jobid_config.subdat,
            custom_metric_name=custom_metric_name,
            custom_metric_mode_name=custom_metric_mode_name)
-        _variable_override_default_alternative(logger,
-                                               self,
+        _variable_override_default_alternative(self,
                                               "metric_name",
                                               default_metric,
                                               all_metrics,
                                               custom_metric_name)
-        _variable_override_default_alternative(logger,
-                                               self,
+        _variable_override_default_alternative(self,
                                               "metric_mode_name",
                                               default_mode,
                                               all_modes,
@@ -620,6 +627,7 @@ class AutoTransformers:
            resources_per_trial=resources_per_trial)
        duration = time.time() - start_time
        self.last_run_duration = duration
+        print("Total running time: {} seconds".format(duration))

        hp_dict = best_run.hyperparameters
        hp_dict["seed"] = int(hp_dict["seed"])
@@ -650,6 +658,18 @@ class AutoTransformers:

        return validation_metric

+    def _set_transformers_verbosity(self, transformers_verbose):
+        if transformers_verbose == transformers.logging.ERROR:
+            transformers.logging.set_verbosity_error()
+        elif transformers_verbose == transformers.logging.WARNING:
+            transformers.logging.set_verbosity_warning()
+        elif transformers_verbose == transformers.logging.INFO:
+            transformers.logging.set_verbosity_info()
+        elif transformers_verbose == transformers.logging.DEBUG:
+            transformers.logging.set_verbosity_debug()
+        else:
+            raise Exception("transformers_verbose must be set to ERROR, WARNING, INFO or DEBUG")
+
    def fit(self,
            num_samples,
            time_budget,
@@ -657,7 +677,8 @@ class AutoTransformers:
            custom_metric_mode_name=None,
            ckpt_per_epoch=1,
            fp16=True,
            verbose=1,
+            ray_verbose=1,
+            transformers_verbose=10,
            resources_per_trial=None,
            ray_local_mode=False,
            **custom_hpo_args):
@@ -688,9 +709,12 @@ class AutoTransformers:
                e.g., "max", "min", "last", "all"
            ckpt_per_epoch:
                An integer value, the number of checkpoints per epoch, default = 1
            verbose:
                int, default=1 | Controls the verbosity, higher means more
                messages
+            ray_verbose:
+                int, default=1 | verbosity of ray
+            transformers_verbose:
+                int, default=transformers.logging.INFO | verbosity of transformers, must be chosen from one of
+                transformers.logging.ERROR, transformers.logging.INFO, transformers.logging.WARNING,
+                or transformers.logging.DEBUG
            fp16:
                boolean, default = True | whether to use fp16
            ray_local_mode:
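For reference, a hypothetical `fit` call exercising the new verbosity arguments (the values mirror the notebook added later in this commit; `autohf` is an `AutoTransformers` instance with data already prepared):

```python
import transformers

validation_metric, analysis = autohf.fit(
    num_samples=1,
    time_budget=100000,
    ckpt_per_epoch=5,
    fp16=True,
    ray_verbose=1,                                    # forwarded to ray.tune.run
    transformers_verbose=transformers.logging.ERROR,  # silence per-trial HF logging
    resources_per_trial={"gpu": 1, "cpu": 1},
)
```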
@@ -709,6 +733,7 @@ class AutoTransformers:

        '''
        from .hpo.scheduler_auto import AutoScheduler
+        self._transformers_verbose = transformers_verbose

        """
        Specify the rest of the jobid config from custom_hpo_args, e.g., if the search algorithm was not specified
@@ -729,12 +754,6 @@ class AutoTransformers:
        self.ckpt_per_epoch = ckpt_per_epoch
        self.path_utils.make_dir_per_run()

-        logger.addHandler(logging.FileHandler(os.path.join(self.path_utils.log_dir_per_run, 'tune.log')))
-        old_level = logger.getEffectiveLevel()
-        self._verbose = verbose
-        if verbose == 0:
-            logger.setLevel(logging.WARNING)
-
        assert self.path_utils.ckpt_dir_per_run
        start_time = time.time()

@@ -748,7 +767,7 @@ class AutoTransformers:
            name="ray_result",
            resources_per_trial=resources_per_trial,
            config=tune_config,
-            verbose=verbose,
+            verbose=ray_verbose,
            local_dir=self.path_utils.ckpt_dir_per_run,
            num_samples=num_samples,
            time_budget_s=time_budget,
@@ -758,7 +777,7 @@ class AutoTransformers:
        )
        duration = time.time() - start_time
        self.last_run_duration = duration
-        logger.info("Total running time: {} seconds".format(duration))
+        print("Total running time: {} seconds".format(duration))

        ray.shutdown()

@@ -774,9 +793,6 @@ class AutoTransformers:

        self._save_ckpt_json(best_ckpt)

-        if verbose == 0:
-            logger.setLevel(old_level)
-
        return validation_metric, analysis

    def predict(self,
@@ -1,7 +1,13 @@
from collections import OrderedDict

-from transformers.models.electra.modeling_electra import ElectraClassificationHead
-from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
+import transformers
+
+if transformers.__version__.startswith("3"):
+    from transformers.modeling_electra import ElectraClassificationHead
+    from transformers.modeling_roberta import RobertaClassificationHead
+else:
+    from transformers.models.electra.modeling_electra import ElectraClassificationHead
+    from transformers.models.roberta.modeling_roberta import RobertaClassificationHead

MODEL_CLASSIFICATION_HEAD_MAPPING = OrderedDict(
    [
@@ -55,12 +55,43 @@ class ConfigScoreList:

    def get_best_config(self,
                        metric_mode="max"):
-        return max(self._config_score_list, key=lambda x: getattr(x, "metric_score")
-                   [metric_mode])
+        return max(self._config_score_list, key=lambda x: getattr(x, "metric_score")[metric_mode])


@dataclass
class JobID:
+    """
+    The class for specifying the config of a job, which includes the following fields:
+
+    dat:
+        A list which is the dataset name
+    subdat:
+        A string which is the sub dataset name
+    mod:
+        A string which is the module, e.g., "grid", "hpo"
+    spa:
+        A string which is the space mode, e.g., "uni", "gnr"
+    arg:
+        A string which is the mode for setting the input argument of a search algorithm, e.g., "cus", "dft"
+    alg:
+        A string which is the search algorithm name
+    pru:
+        A string which is the scheduler name
+    pre_full:
+        A string which is the full name of the pretrained language model
+    pre:
+        A string which is the abbreviation of the pretrained language model
+    presz:
+        A string which is the size of the pretrained language model
+    spt:
+        A string which is the resplit mode, e.g., "ori", "rspt"
+    rep:
+        An integer which is the repetition id
+    sddt:
+        An integer which is the seed for data shuffling in the resplit mode
+    sdhf:
+        An integer which is the seed for transformers
+    """
    dat: list = field(default=None)
    subdat: str = field(default=None)
    mod: str = field(default=None)
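The abbreviations above are the building blocks of the run directory names that appear in the notebook later in this commit. A small illustration (the encoded string below is copied from the notebook's Ray logdir output, not computed here):

```python
jobid_config = JobID()
jobid_config.dat = ["glue"]
jobid_config.subdat = "mrpc"
jobid_config.mod = "grid"
# A fully specified JobID encodes to a directory name such as:
# dat=glue_subdat=mrpc_mod=grid_spa=grid_arg=dft_alg=grid_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None
```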
@@ -313,8 +344,8 @@ class JobID:
            "rep_id": "rep",
            "seed_data": "sddt",
            "seed_transformers": "sdhf",
-            "optarg1": "var1",
-            "optarg2": "var2"
+            "learning_rate": "var1",
+            "weight_decay": "var2"
        }
        for each_key in console_to_jobid_key_mapping.keys():
            try:
@@ -346,30 +377,120 @@ class AzureUtils:

    def __init__(self,
                 root_log_path=None,
-                 console_args=None,
-                 autohf=None):
-        from ..utils import get_wandb_azure_key
+                 azure_key_path=None,
+                 data_root_dir=None,
+                 autohf=None,
+                 jobid_config=None):
+        ''' This class is for saving the output files (logs, predictions) for HPO, uploading them to an azure storage
+        blob, and performing analysis on the saved blobs. To use the cloud storage, you need to specify a key
+        and upload the output files to azure. For example, when running jobs in a cluster, this class can
+        help you store all the output files in the same place. If a key is not specified, this class will help you
+        save the files locally, without uploading them to the cloud. After the outputs are uploaded, you can use this
+        class to perform analysis on the uploaded blob files.
+
+        Examples:
+
+            Example 1 (saving and uploading):
+
+                validation_metric, analysis = autohf.fit(**autohf_settings)  # running HPO
+                predictions, test_metric = autohf.predict()
+
+                azure_utils = AzureUtils(root_log_path="logs_test/",
+                                         autohf=autohf,
+                                         azure_key_path="../../")
+                # passing the azure blob key from key.json under azure_key_path
+
+                azure_utils.write_autohf_output(valid_metric=validation_metric,
+                                                predictions=predictions,
+                                                duration=autohf.last_run_duration)
+                # uploading the output to azure cloud, which can be used for analysis afterwards
+
+            Example 2 (analysis):
+
+                jobid_config = JobID()
+                jobid_config.mod = "grid"
+                jobid_config.pre = "funnel"
+                jobid_config.presz = "xlarge"
+
+                azure_utils = AzureUtils(root_log_path="logs_test/",
+                                         azure_key_path="../../",
+                                         jobid_config=jobid_config)
+
+                # continue analyzing all files in azure blob that match jobid_config
+
+        Args:
+            root_log_path:
+                The local root log folder name, e.g., root_log_path="logs_test/" will create a directory
+                "logs_test/" locally
+
+            azure_key_path:
+                The path for storing the azure keys. The azure key and container name are stored in a local file
+                azure_key_path/key.json. The key.json file should look like this:
+
+                {
+                    "container_name": "container_name",
+                    "azure_key": "azure_key",
+                }
+
+                To find out the container name and azure key of your blob, please refer to:
+                https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-portal
+
+                If the container name and azure key are not specified, the output will only be saved locally,
+                not synced to azure blob.
+
+            data_root_dir:
+                The directory for outputting the predictions, e.g., packing the predictions into a .zip file for
+                uploading to the glue website
+
+            autohf:
+                The AutoTransformers object, which contains the output of an HPO run. AzureUtils will save the
+                output (analysis results, predictions) from AutoTransformers.
+
+            jobid_config:
+                The jobid config for analysis. jobid_config specifies the jobid config of azure blob files
+                to be analyzed; if autohf is specified, jobid_config will be overwritten by autohf.jobid_config
+        '''
+        if root_log_path:
+            self.root_log_path = root_log_path
+        else:
+            self.root_log_path = "logs_azure"
-        self.jobid = autohf.jobid_config
-        self.console_args = console_args
-        self.root_log_path = "logs_azure/"
+        if autohf is not None:
+            self.jobid = autohf.jobid_config
+        else:
+            assert jobid_config is not None, "jobid_config must be passed either through autohf.jobid_config" \
+                                             " or jobid_config"
+            self.jobid = jobid_config
+        self.data_root_dir = data_root_dir
        self.autohf = autohf
-        if console_args:
-            wandb_key, azure_key, container_name = get_wandb_azure_key(console_args.key_path)
+        if azure_key_path:
+            azure_key, container_name = AzureUtils.get_azure_key(azure_key_path)
            self._container_name = container_name
            self._azure_key = azure_key
        else:
            self._container_name = self._azure_key = ""

+    @staticmethod
+    def get_azure_key(key_path):
+        try:
+            try:
+                key_json = json.load(open(os.path.join(key_path, "key.json"), "r"))
+                azure_key = key_json["azure_key"]
+                azure_container_name = key_json["container_name"]
+                return azure_key, azure_container_name
+            except FileNotFoundError:
+                print("Your output will not be synced to azure because key.json is not found under key_path")
+                return "", ""
+        except KeyError:
+            print("Your output will not be synced to azure because azure key and container name are not specified")
+            return "", ""

    def _get_complete_connection_string(self):
        try:
            return "DefaultEndpointsProtocol=https;AccountName=docws5141197765;AccountKey=" \
                   + self._azure_key + ";EndpointSuffix=core.windows.net"
        except AttributeError:
            return "DefaultEndpointsProtocol=https;AccountName=docws5141197765;AccountKey=" \
                   ";EndpointSuffix=core.windows.net"

    def _init_azure_clients(self):
        try:
@@ -380,11 +501,10 @@ class AzureUtils:
                container_name=self._container_name)
            return container_client
        except ValueError:
-            print("AzureUtils._container_name is specified as: {}, "
-                  "please correctly specify AzureUtils._container_name".format(self._container_name))
+            print("Your output will not be synced to azure because azure key and container name are not specified")
            return None
        except ImportError:
-            print("To use the azure storage component in flaml.nlp, run pip install azure-storage-blob")
+            print("Your output will not be synced to azure because azure-blob-storage is not installed")

    def _init_blob_client(self,
                          local_file_path):
@@ -397,10 +517,10 @@ class AzureUtils:
            blob_client = blob_service_client.get_blob_client(container=self._container_name, blob=local_file_path)
            return blob_client
        except ValueError:
-            print("_container_name is unspecified or wrongly specified, please specify _container_name in AzureUtils")
+            print("Your output will not be synced to azure because azure key and container name are not specified")
            return None
        except ImportError:
-            print("To use the azure storage component in flaml.nlp, run pip install azure-storage-blob")
+            print("Your output will not be synced to azure because azure-blob-storage is not installed")

    def upload_local_file_to_azure(self, local_file_path):
        try:
@@ -412,9 +532,9 @@ class AzureUtils:
                blob_client.upload_blob(fin, overwrite=True)
        except HttpResponseError as err:
            print("Cannot upload blob due to {}: {}".format("azure.core.exceptions.HttpResponseError",
                                                            err))
        except ImportError:
-            print("To use the azure storage component in flaml.nlp, run pip install azure-storage-blob")
+            print("Your output will not be synced to azure because azure-blob-storage is not installed")

    def download_azure_blob(self, blobname):
        blob_client = self._init_blob_client(blobname)
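The error handling above follows one pattern throughout `AzureUtils`: degrade to local-only mode instead of raising. A standalone sketch of that pattern (the helper name is hypothetical; the azure-storage-blob calls are the same ones the class wraps):

```python
def try_get_container_client(connection_string, container_name):
    """Return a ContainerClient, or None if azure sync is unavailable."""
    try:
        from azure.storage.blob import ContainerClient
        return ContainerClient.from_connection_string(
            conn_str=connection_string, container_name=container_name)
    except ImportError:
        print("Your output will not be synced to azure because azure-blob-storage is not installed")
    except ValueError:
        print("Your output will not be synced to azure because azure key and container name are not specified")
    return None
```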
@@ -497,13 +617,14 @@ class AzureUtils:
            store predictions (a .zip file) locally and upload
        """
        azure_save_file_name = local_json_file.split("/")[-1][:-5]
-        try:
-            output_dir = self.console_args.data_root_dir
-        except AttributeError:
-            print("console_args does not contain data_root_dir, loading the default value")
+        if self.data_root_dir is None:
+            from ..utils import load_dft_args
+            console_args = load_dft_args()
+            output_dir = getattr(console_args, "data_root_dir")
+            print("The path for saving the prediction .zip file is not specified, "
+                  "setting to {} by default".format(output_dir))
+        else:
+            output_dir = self.data_root_dir
        local_archive_path = self.autohf.output_prediction(predictions,
                                                           output_prediction_path=output_dir + "result/",
                                                           output_zip_file_name=azure_save_file_name)
@@ -2,6 +2,7 @@ import os
import subprocess
import hashlib
from time import time
+import json


class WandbUtils:
@@ -25,11 +26,10 @@ class WandbUtils:

    def __init__(self,
                 is_wandb_on=False,
-                 console_args=None,
+                 wandb_key_path=None,
                 jobid_config=None):
        if is_wandb_on:
-            from ..utils import get_wandb_azure_key
-            wandb_key, azure_key, container_name = get_wandb_azure_key(console_args.key_path)
+            wandb_key = WandbUtils.get_wandb_key(wandb_key_path)
            if wandb_key != "":
                subprocess.run(["wandb", "login", "--relogin", wandb_key])
                os.environ["WANDB_API_KEY"] = wandb_key
@@ -38,6 +38,20 @@ class WandbUtils:
            os.environ["WANDB_MODE"] = "disabled"
        self.jobid_config = jobid_config

+    @staticmethod
+    def get_wandb_key(key_path):
+        try:
+            try:
+                key_json = json.load(open(os.path.join(key_path, "key.json"), "r"))
+                wandb_key = key_json["wandb_key"]
+                return wandb_key
+            except FileNotFoundError:
+                print("Cannot use wandb module because key.json is not found under key_path")
+                return ""
+        except KeyError:
+            print("Cannot use wandb module because wandb key is not specified")
+            return ""
+
    def set_wandb_per_trial(self):
        print("before wandb.init\n\n\n")
        try:
@@ -57,7 +71,7 @@ class WandbUtils:
            print(err)
            return None
        except ImportError:
-            print("To use the wandb component in flaml.nlp, run pip install wandb==0.10.26")
+            print("Cannot use wandb module because wandb is not installed, run pip install wandb==0.10.26")

    @staticmethod
    def _get_next_trial_ids():
@@ -84,4 +98,4 @@ class WandbUtils:
            print(err)
            return None
        except ImportError:
-            print("To use the wandb component in flaml.nlp, run pip install wandb==0.10.26")
+            print("Cannot use wandb module because wandb is not installed, run pip install wandb==0.10.26")
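`WandbUtils.get_wandb_key` above and `AzureUtils.get_azure_key` earlier read the same `key.json` file. A hypothetical one-time setup for it (the field names are the ones those readers expect; the values are placeholders):

```python
import json
import os

key_path = "."  # the directory later passed as wandb_key_path / azure_key_path
with open(os.path.join(key_path, "key.json"), "w") as fout:
    json.dump({
        "wandb_key": "<your wandb API key>",
        "azure_key": "<your azure storage account key>",
        "container_name": "<your azure blob container name>",
    }, fout)
```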
@@ -56,24 +56,12 @@ def load_dft_args():
    arg_parser.add_argument('--round_idx', type=int, help='round idx for acl experiments', required=False, default=0)
    arg_parser.add_argument('--seed_data', type=int, help='seed of data shuffling', required=False, default=43)
    arg_parser.add_argument('--seed_transformers', type=int, help='seed of transformers', required=False, default=42)
-    arg_parser.add_argument('--optarg1', type=float, help='place holder for optional arg', required=False)
-    arg_parser.add_argument('--optarg2', type=float, help='place holder for optional arg', required=False)
+    arg_parser.add_argument('--learning_rate', type=float, help='optional arg learning_rate', required=False)
+    arg_parser.add_argument('--weight_decay', type=float, help='optional arg weight_decay', required=False)
    args, unknown = arg_parser.parse_known_args()
    return args


-def get_wandb_azure_key(key_path):
-    try:
-        key_json = json.load(open(os.path.join(key_path, "key.json"), "r"))
-        wandb_key = key_json["wandb_key"]
-        azure_key = key_json["azure_key"]
-        azure_container_name = key_json["container_name"]
-        return wandb_key, azure_key, azure_container_name
-    except FileNotFoundError:
-        print("File not found for key.json under", key_path)
-        return "", "", ""
-
-
def merge_dicts(dict1, dict2):
    for key2 in dict2.keys():
        if key2 in dict1:
@@ -91,7 +79,7 @@ def _check_dict_keys_overlaps(dict1: dict, dict2: dict):
    return len(dict1_keys.intersection(dict2_keys)) > 0


-def _variable_override_default_alternative(logger, obj_ref, var_name, default_value, all_values, overriding_value=None):
+def _variable_override_default_alternative(obj_ref, var_name, default_value, all_values, overriding_value=None):
    """
    Setting the value of var. If overriding_value is specified, var is set to overriding_value;
    If overriding_value is not specified, var is set to default_value while also showing all_values

@@ -99,11 +87,11 @@ def _variable_override_default_alternative(obj_ref, var_name, default_va
    assert isinstance(all_values, list)
    if overriding_value:
        setattr(obj_ref, var_name, overriding_value)
-        logger.warning("The value for {} is specified as {}".format(var_name, overriding_value))
+        print("The value for {} is specified as {}".format(var_name, overriding_value))
    else:
        setattr(obj_ref, var_name, default_value)
-        logger.warning("The value for {} is not specified, setting it to the default value {}. "
-                       "Alternatively, you can set it to {}".format(var_name, default_value, ",".join(all_values)))
+        print("The value for {} is not specified, setting it to the default value {}. "
+              "Alternatively, you can set it to {}".format(var_name, default_value, ",".join(all_values)))


@dataclass
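A hypothetical call showing the helper's behavior after the signature change (the `Holder` class and the values are invented for illustration):

```python
class Holder:
    pass

holder = Holder()
# No overriding value: fall back to the default and print the alternatives.
_variable_override_default_alternative(holder, "metric_name", "accuracy", ["accuracy", "f1"], None)
assert holder.metric_name == "accuracy"
# With an overriding value: the override wins.
_variable_override_default_alternative(holder, "metric_name", "accuracy", ["accuracy", "f1"], "f1")
assert holder.metric_name == "f1"
```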
@@ -0,0 +1,814 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) 2020-2021. All rights reserved.\n",
    "\n",
    "Licensed under the MIT License.\n",
    "\n",
    "# Troubleshooting HPO for fine-tuning pre-trained language models\n",
    "\n",
    "## 1. Introduction\n",
    "\n",
    "\n",
    "In this notebook, we demonstrate a procedure for troubleshooting HPO failure in fine-tuning pre-trained language models (introduced in the following paper):\n",
    "\n",
    "*[An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. To appear in ACL-IJCNLP 2021*\n",
    "\n",
    "Notes:\n",
    "\n",
    "* In this notebook, we only run each experiment 1 time for simplicity, which is different from the paper (3 times). To reproduce the paper's results, please run 3 repetitions and take the average scores.\n",
    "\n",
    "* Running this notebook takes about one hour.\n",
    "\n",
    "FLAML requires `Python>=3.6`. To run this notebook example, please install flaml with the `notebook` and `nlp` options:\n",
    "```bash\n",
    "pip install flaml[nlp]\n",
    "```\n",
    "Our paper was developed under transformers version 3.4.0, so we reinstall transformers==3.4.0:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "!pip install flaml[nlp]\n",
    "!pip install transformers==3.4.0\n",
    "from flaml.nlp import AutoTransformers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Initial Experimental Study\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset\n",
    "\n",
    "Load the dataset using AutoTransformers.prepare_data. In this notebook, we use the Microsoft Research Paraphrasing Corpus (MRPC) dataset and the Electra model as an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "console_args has no attribute pretrained_model_size, continue\n",
      "console_args has no attribute dataset_subdataset_name, continue\n",
      "console_args has no attribute algo_mode, continue\n",
      "console_args has no attribute space_mode, continue\n",
      "console_args has no attribute search_alg_args_mode, continue\n",
      "console_args has no attribute algo_name, continue\n",
      "console_args has no attribute pruner, continue\n",
      "console_args has no attribute resplit_mode, continue\n",
      "console_args has no attribute rep_id, continue\n",
      "console_args has no attribute seed_data, continue\n",
      "console_args has no attribute seed_transformers, continue\n",
      "console_args has no attribute learning_rate, continue\n",
      "console_args has no attribute weight_decay, continue\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)\n",
      "Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-6a78e5c95406457c.arrow\n",
      "Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-e8d0f3e04c3b4588.arrow\n",
      "Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-4b0966b394994163.arrow\n",
      "Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-6a78e5c95406457c.arrow\n",
      "Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-e8d0f3e04c3b4588.arrow\n",
      "Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-4b0966b394994163.arrow\n"
     ]
    }
   ],
   "source": [
    "autohf = AutoTransformers()\n",
    "preparedata_setting = {\n",
    "    \"dataset_subdataset_name\": \"glue:mrpc\",\n",
    "    \"pretrained_model_size\": \"google/electra-base-discriminator:base\",\n",
    "    \"data_root_path\": \"data/\",\n",
    "    \"max_seq_length\": 128,\n",
    "}\n",
    "autohf.prepare_data(**preparedata_setting)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Running grid search\n",
    "\n",
    "First, we run grid search using Electra. By specifying `algo_mode=\"grid\"`, AutoTransformers will run the grid search algorithm. By specifying `space_mode=\"grid\"`, AutoTransformers will use the default grid search configuration recommended by the Electra paper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "== Status ==<br>Memory usage on this node: 14.2/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/250.73 GiB heap, 0.0/76.9 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: 67d99_00002 with accuracy=0.7254901960784313 and parameters={'learning_rate': 0.0001, 'weight_decay': 0.0, 'adam_epsilon': 1e-06, 'warmup_ratio': 0.1, 'per_device_train_batch_size': 32, 'hidden_dropout_prob': 0.1, 'attention_probs_dropout_prob': 0.1, 'num_train_epochs': 0.5, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=grid_spa=grid_arg=dft_alg=grid_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 4/4 (4 TERMINATED)<br><br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-16 10:45:35,071\tINFO tune.py:450 -- Total run time: 106.56 seconds (106.41 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total running time: 106.57789206504822 seconds\n"
     ]
    }
   ],
   "source": [
    "import transformers\n",
    "autohf_settings = {\n",
    "    \"resources_per_trial\": {\"gpu\": 1, \"cpu\": 1},\n",
    "    \"num_samples\": 1,\n",
    "    \"time_budget\": 100000,  # unlimited time budget\n",
    "    \"ckpt_per_epoch\": 5,\n",
    "    \"fp16\": True,\n",
    "    \"algo_mode\": \"grid\",  # set the search algorithm to grid search\n",
    "    \"space_mode\": \"grid\",  # set the search space to the recommended grid space\n",
    "    \"transformers_verbose\": transformers.logging.ERROR\n",
    "}\n",
    "validation_metric, analysis = autohf.fit(**autohf_settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the time for running grid search: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "grid search for glue_mrpc took 106.57789206504822 seconds\n"
     ]
    }
   ],
   "source": [
    "GST = autohf.last_run_duration\n",
    "print(\"grid search for {} took {} seconds\".format(autohf.jobid_config.get_jobid_full_data_name(), GST))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the HPO run finishes, generate the predictions and save them as a .zip file to be submitted to the glue website. Here we use the AzureUtils class, which stores the output information (e.g., analysis log, .zip file) locally and uploads the output to an azure blob container (e.g., if multiple jobs are executed in a cluster). If the azure key and container information is not specified, the output information will only be saved locally. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remove_columns_ is deprecated and will be removed in the next major version of datasets. Use the dataset.remove_columns method instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cleaning the existing label column from test data\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <style>\n",
       "        /* Turns off some styling */\n",
       "        progress {\n",
       "            /* gets rid of default border in Firefox and Opera. */\n",
       "            border: none;\n",
       "            /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "            background-size: auto;\n",
       "        }\n",
       "      </style>\n",
       "      \n",
       "    <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "    [432/432 00:34]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JobID(dat=['glue'], subdat='mrpc', mod='grid', spa='grid', arg='dft', alg='grid', pru='None', pre_full='google/electra-base-discriminator', pre='electra', presz='base', spt='ori', rep=0, sddt=43, sdhf=42, var1=None, var2=None)\n",
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "{'eval_accuracy': 0.7254901960784313, 'eval_f1': 0.8276923076923076, 'eval_loss': 0.516851007938385}\n"
     ]
    }
   ],
   "source": [
    "predictions, test_metric = autohf.predict()\n",
    "from flaml.nlp import AzureUtils\n",
    "\n",
    "print(autohf.jobid_config)\n",
    "\n",
    "azure_utils = AzureUtils(root_log_path=\"logs_test/\", autohf=autohf)\n",
    "azure_utils.write_autohf_output(valid_metric=validation_metric,\n",
    "                                predictions=predictions,\n",
    "                                duration=GST)\n",
    "print(validation_metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The validation F1/accuracy we got was 92.4/89.5. After the above steps, you will find a .zip file for the predictions under data/result/. Submit the .zip file to the glue website. The test F1/accuracy we got was 90.4/86.7. As an example, we only run the experiment once, but in general one should run the experiment multiple times and report the averaged validation and test accuracy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Running Random Search\n",
    "\n",
    "Next, we run random search with the same time budget as grid search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def tune_hpo(time_budget, this_hpo_space):\n",
    "    autohf_settings = {\n",
    "        \"resources_per_trial\": {\"gpu\": 1, \"cpu\": 1},\n",
    "        \"num_samples\": -1,\n",
    "        \"time_budget\": time_budget,\n",
    "        \"ckpt_per_epoch\": 5,\n",
    "        \"fp16\": True,\n",
    "        \"algo_mode\": \"hpo\",  # set the search algorithm mode to hpo\n",
    "        \"algo_name\": \"rs\",\n",
    "        \"space_mode\": \"cus\",  # customized search space (this_hpo_space)\n",
    "        \"hpo_space\": this_hpo_space,\n",
    "        \"transformers_verbose\": transformers.logging.ERROR\n",
    "    }\n",
    "    validation_metric, analysis = autohf.fit(**autohf_settings)\n",
    "    predictions, test_metric = autohf.predict()\n",
    "    azure_utils = AzureUtils(root_log_path=\"logs_test/\", autohf=autohf)\n",
    "    azure_utils.write_autohf_output(valid_metric=validation_metric,\n",
    "                                    predictions=predictions,\n",
    "                                    duration=GST)\n",
    "    print(validation_metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "== Status ==<br>Memory usage on this node: 30.1/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.51 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: c67b4_00003 with accuracy=0.7303921568627451 and parameters={'learning_rate': 4.030097060410288e-05, 'warmup_ratio': 0.06084844859190755, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 16, 'weight_decay': 0.15742692948967135, 'attention_probs_dropout_prob': 0.08638900372842316, 'hidden_dropout_prob': 0.058245828039608386, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 8/infinite (8 TERMINATED)<br><br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
      "\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
      "\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
      "\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-16 10:48:21,624\tINFO tune.py:450 -- Total run time: 114.32 seconds (109.41 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total running time: 114.35665488243103 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <style>\n",
       "        /* Turns off some styling */\n",
       "        progress {\n",
       "            /* gets rid of default border in Firefox and Opera. */\n",
       "            border: none;\n",
       "            /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "            background-size: auto;\n",
       "        }\n",
       "      </style>\n",
       "      \n",
       "    <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "    [432/432 00:33]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "{'eval_accuracy': 0.7328431372549019, 'eval_f1': 0.8320493066255777, 'eval_loss': 0.5411379933357239}\n"
     ]
    }
   ],
   "source": [
    "hpo_space_full = {\n",
    "    \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
    "    \"warmup_ratio\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
    "    \"num_train_epochs\": [3],\n",
    "    \"per_device_train_batch_size\": [16, 32, 64],\n",
    "    \"weight_decay\": {\"l\": 0.0, \"u\": 0.3, \"space\": \"linear\"},\n",
    "    \"attention_probs_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
    "    \"hidden_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
    "}\n",
    "\n",
    "tune_hpo(GST, hpo_space_full)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The validation F1/accuracy we got was 93.5/90.9. Similarly, we can submit the .zip file to the glue website. The test F1/accuracy we got was 81.6/70.2. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 3. Troubleshooting HPO Failures\n",
    "\n",
    "Since the validation accuracy is higher than grid search's while the test accuracy is lower, the HPO run has overfit. We reduce the search space:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "== Status ==<br>Memory usage on this node: 26.5/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.51 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: 234d8_00003 with accuracy=0.7475490196078431 and parameters={'learning_rate': 0.00011454435497690623, 'warmup_ratio': 0.1, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 16, 'weight_decay': 0.06370173320348284, 'attention_probs_dropout_prob': 0.03636499344142013, 'hidden_dropout_prob': 0.03668090197068676, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 6/infinite (6 TERMINATED)<br><br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-16 10:51:34,598\tINFO tune.py:450 -- Total run time: 151.57 seconds (136.77 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total running time: 151.59901237487793 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <style>\n",
       "        /* Turns off some styling */\n",
       "        progress {\n",
       "            /* gets rid of default border in Firefox and Opera. */\n",
       "            border: none;\n",
       "            /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "            background-size: auto;\n",
       "        }\n",
       "      </style>\n",
       "      \n",
       "    <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "    [432/432 00:33]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "{'eval_accuracy': 0.7475490196078431, 'eval_f1': 0.8325203252032519, 'eval_loss': 0.5056071877479553}\n"
     ]
    }
   ],
   "source": [
    "hpo_space_fixwr = {\n",
    "    \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
    "    \"warmup_ratio\": [0.1],\n",
    "    \"num_train_epochs\": [3],\n",
    "    \"per_device_train_batch_size\": [16, 32, 64],\n",
    "    \"weight_decay\": {\"l\": 0.0, \"u\": 0.3, \"space\": \"linear\"},\n",
    "    \"attention_probs_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
    "    \"hidden_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
    "}\n",
    "tune_hpo(GST, hpo_space_fixwr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation F1/accuracy we got was 92.6/89.7, and the test F1/accuracy was 85.9/78.7; overfitting persists, so we further reduce the space: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "== Status ==<br>Memory usage on this node: 29.6/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.46 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: 96a67_00003 with accuracy=0.7107843137254902 and parameters={'learning_rate': 7.862589064613256e-05, 'warmup_ratio': 0.1, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 32, 'weight_decay': 0.0, 'attention_probs_dropout_prob': 0.1, 'hidden_dropout_prob': 0.1, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 6/infinite (6 TERMINATED)<br><br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
      "\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
      "\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
      "\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
      "\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
      "\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
      "\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
      "\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
      "\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-16 10:54:14,542\tINFO tune.py:450 -- Total run time: 117.99 seconds (112.99 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total running time: 118.01927375793457 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <style>\n",
       "        /* Turns off some styling */\n",
       "        progress {\n",
       "            /* gets rid of default border in Firefox and Opera. */\n",
       "            border: none;\n",
       "            /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "            background-size: auto;\n",
       "        }\n",
       "      </style>\n",
       "      \n",
       "    <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "    [432/432 00:33]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "{'eval_accuracy': 0.7181372549019608, 'eval_f1': 0.8174962292609351, 'eval_loss': 0.5494586229324341}\n"
     ]
    }
   ],
   "source": [
    "hpo_space_min = {\n",
    "    \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
    "    \"warmup_ratio\": [0.1],\n",
    "    \"num_train_epochs\": [3],\n",
    "    \"per_device_train_batch_size\": [16, 32, 64],\n",
    "    \"weight_decay\": [0.0],\n",
    "    \"attention_probs_dropout_prob\": [0.1],\n",
    "    \"hidden_dropout_prob\": [0.1],\n",
    "}\n",
    "tune_hpo(GST, hpo_space_min)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The validation F1/accuracy we got was 90.4/86.7, and the test F1/accuracy was 83.0/73.0. Since the validation accuracy is below grid search's, we increase the budget to 4 * GST:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "== Status ==<br>Memory usage on this node: 26.2/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.46 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: f5d31_00005 with accuracy=0.7352941176470589 and parameters={'learning_rate': 3.856175093679045e-05, 'warmup_ratio': 0.1, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 16, 'weight_decay': 0.0, 'attention_probs_dropout_prob': 0.1, 'hidden_dropout_prob': 0.1, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 16/infinite (16 TERMINATED)<br><br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
      "\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
      "\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
      "\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
      "\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
      "\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
      "\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
      "\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
      "\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
      "\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
      "\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-16 11:03:23,308\tINFO tune.py:450 -- Total run time: 507.09 seconds (445.79 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total running time: 507.15925645828247 seconds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <style>\n",
       "        /* Turns off some styling */\n",
       "        progress {\n",
       "            /* gets rid of default border in Firefox and Opera. */\n",
       "            border: none;\n",
       "            /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "            background-size: auto;\n",
       "        }\n",
       "      </style>\n",
       "      \n",
       "    <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "    [432/432 00:34]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
      "Your output will not be synced to azure because azure key and container name are not specified\n",
      "{'eval_accuracy': 0.7401960784313726, 'eval_f1': 0.8333333333333334, 'eval_loss': 0.5303606986999512}\n"
     ]
    }
   ],
   "source": [
    "hpo_space_min = {\n",
    "    \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
    "    \"warmup_ratio\": [0.1],\n",
    "    \"num_train_epochs\": [3],\n",
    "    \"per_device_train_batch_size\": [16, 32, 64],\n",
    "    \"weight_decay\": [0.0],\n",
    "    \"attention_probs_dropout_prob\": [0.1],\n",
    "    \"hidden_dropout_prob\": [0.1],\n",
    "}\n",
    "tune_hpo(4 * GST, hpo_space_min)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation F1/accuracy we got was 92.3/89.7, where the accuracy outperforms grid search. The test F1/accuracy was 91.1/87.8. As a result, random search with 4*GST and the minimal search space `hpo_space_min` has outperformed grid search. We stop the troubleshooting process."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
@@ -109,7 +109,7 @@ def test_azureutils():
        configscore_list.sorted(each_method)
        configscore_list.get_best_config()

-        azureutils = AzureUtils(console_args=args, autohf=autohf)
+        azureutils = AzureUtils(azure_key_path=args.key_path, data_root_dir=args.data_root_dir, autohf=autohf)
        azureutils.autohf = autohf
        azureutils.root_log_path = "logs_azure/"
@@ -198,7 +198,7 @@ def test_wandb_utils():
    args.key_path = "."
    jobid_config = JobID(args)

-    wandb_utils = WandbUtils(is_wandb_on=True, console_args=args, jobid_config=jobid_config)
+    wandb_utils = WandbUtils(is_wandb_on=True, wandb_key_path=args.key_path, jobid_config=jobid_config)
    os.environ["WANDB_MODE"] = "online"
    wandb_utils.wandb_group_name = "test"
    wandb_utils._get_next_trial_ids()
@@ -54,7 +54,7 @@ def test_hpo():
    if test_metric:
        validation_metric.update({"test": test_metric})

-    azure_utils = AzureUtils(root_log_path="logs_test/", autohf=autohf)
+    azure_utils = AzureUtils(root_log_path="logs_test/", data_root_dir="data/", autohf=autohf)
    azure_utils._azure_key = "test"
    azure_utils._container_name = "test"