fixed API doc and import (#108)

* removed run_analysis.py, run_autohf.py, test_jupyter.py
This commit is contained in:
Xueqing Liu 2021-06-15 12:55:23 -04:00 committed by GitHub
parent 926589bdda
commit a5a5a4bc20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 26 additions and 18 deletions

View File

@ -51,6 +51,12 @@ NLP
.. autoclass:: flaml.nlp.AutoTransformers
:members:
.. autoclass:: flaml.nlp.AzureUtils
:members:
.. autoclass:: flaml.nlp.JobID
:members:
.. Indices and tables
.. ==================

View File

@ -1,3 +1,2 @@
from .hpo.hpo_searchspace import AutoHPOSearchSpace
from .autotransformers import AutoTransformers
from .result_analysis.azure_utils import AzureUtils, JobID

View File

@ -803,7 +803,7 @@ class AutoTransformers:
test_trainer = TrainerForAutoTransformers(best_model, training_args)
if self.jobid_config.spt == "ori":
if "label" in self.test_dataset.keys():
if "label" in self.test_dataset.features.keys():
self.test_dataset.remove_columns_("label")
print("Cleaning the existing label column from test data")

View File

@ -287,7 +287,7 @@ def get_electra_space(model_size_type=None,
"num_train_epochs": [3],
},
"glue_mrpc": {
"num_train_epochs": [0.2],
"num_train_epochs": [3],
},
"glue_cola": {
"num_train_epochs": [3],

View File

@ -65,7 +65,7 @@ class AutoSearchAlgorithm:
if not search_algo_name:
search_algo_name = "grid"
if search_algo_name in SEARCH_ALGO_MAPPING.keys():
if search_algo_name == "grid":
if SEARCH_ALGO_MAPPING[search_algo_name] is None:
return None
"""
filtering the customized args for hpo from custom_hpo_args, keep those

View File

@ -50,11 +50,13 @@ class ConfigScoreList:
self._config_score_list = sorted(self._config_score_list, key=lambda x: x.start_time, reverse=False)
else:
self._config_score_list = sorted(self._config_score_list,
key=lambda x: getattr(x, "metric_score")[metric_mode], reverse=True)
key=lambda x: getattr(x, "metric_score")
[metric_mode], reverse=True)
def get_best_config(self,
metric_mode="max"):
return max(self._config_score_list, key=lambda x: getattr(x, "metric_score")[metric_mode])
return max(self._config_score_list, key=lambda x: getattr(x, "metric_score")
[metric_mode])
@dataclass
@ -318,11 +320,11 @@ class JobID:
try:
try:
if each_key == "dataset_subdataset_name":
dataset_subdataset_name_format_check(getattr(console_args, each_key))
dataset_subdataset_name_format_check(JobID.get_attrval_from_arg_or_dict(console_args, each_key))
self.dat = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[0].split(",")
self.subdat = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[1]
elif each_key == "pretrained_model_size":
pretrained_model_size_format_check(getattr(console_args, each_key))
pretrained_model_size_format_check(JobID.get_attrval_from_arg_or_dict(console_args, each_key))
self.pre_full = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[0]
self.pre = JobID.extract_model_type(self.pre_full)
self.presz = JobID.get_attrval_from_arg_or_dict(console_args, each_key).split(":")[1]

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass, field
def dataset_subdataset_name_format_check(val_str):
regex = re.compile(r"^[^:]*:[^:]*$")
if (val_str is not None) and (not regex.match(val_str)):
if (val_str is not None) and (not regex.search(val_str)):
raise argparse.ArgumentTypeError("dataset_subdataset_name must be in the format {data_name}:{subdata_name}")
return val_str

View File

@ -76,7 +76,7 @@ setuptools.setup(
"nlp": [
"ray[tune]>=1.2.0",
"transformers",
"datasets==1.4",
"datasets==1.4.1",
"torch"
]
},

View File

@ -40,7 +40,7 @@ def test_get_configblob_from_partial_jobid():
except ImportError:
return
from flaml.nlp.result_analysis.azure_utils import JobID
from flaml.nlp import JobID
each_blob_name = "dat=glue_subdat=cola_mod=grid_spa=cus_arg=dft_alg=grid" \
"_pru=None_pre=deberta_presz=large_spt=rspt_rep=0_sddt=43" \
"_sdhf=42_var1=1e-05_var2=0.0.json"
@ -70,7 +70,7 @@ def test_jobid():
except ImportError:
return
from flaml.nlp.result_analysis.azure_utils import JobID
from flaml.nlp import JobID
args = get_console_args()
jobid_config = JobID(args)
@ -89,7 +89,8 @@ def test_azureutils():
except ImportError:
return
from flaml.nlp.result_analysis.azure_utils import AzureUtils, ConfigScore, JobID, ConfigScoreList
from flaml.nlp import AzureUtils, JobID
from flaml.nlp.result_analysis.azure_utils import ConfigScore, ConfigScoreList
from flaml.nlp import AutoTransformers
args = get_console_args()

View File

@ -24,7 +24,7 @@ def get_console_args():
def model_init():
from flaml.nlp.result_analysis.azure_utils import JobID
from flaml.nlp import JobID
jobid_config = JobID()
jobid_config.set_unittest_config()
from flaml.nlp import AutoTransformers
@ -88,7 +88,7 @@ def test_gridsearch_space():
return
from flaml.nlp.hpo.grid_searchspace_auto import GRID_SEARCH_SPACE_MAPPING, AutoGridSearchSpace
from flaml.nlp.result_analysis.azure_utils import JobID
from flaml.nlp import JobID
jobid_config = JobID()
jobid_config.set_unittest_config()
@ -107,7 +107,7 @@ def test_hpo_space():
return
from flaml.nlp.hpo.hpo_searchspace import AutoHPOSearchSpace, HPO_SEARCH_SPACE_MAPPING
from flaml.nlp.result_analysis.azure_utils import JobID
from flaml.nlp import JobID
jobid_config = JobID()
jobid_config.set_unittest_config()
@ -168,7 +168,7 @@ def test_switch_head():
return
from flaml.nlp.huggingface.switch_head_auto import AutoSeqClassificationHead, MODEL_CLASSIFICATION_HEAD_MAPPING
from flaml.nlp.result_analysis.azure_utils import JobID
from flaml.nlp import JobID
jobid_config = JobID()
jobid_config.set_unittest_config()
checkpoint_path = jobid_config.pre_full
@ -191,7 +191,7 @@ def test_wandb_utils():
return
from flaml.nlp.result_analysis.wandb_utils import WandbUtils
from flaml.nlp.result_analysis.azure_utils import JobID
from flaml.nlp import JobID
import os
args = get_console_args()