mirror of https://github.com/microsoft/autogen.git
Finish the Multiple Choice Classification (#367)
* adding multiple choice
* update test cases (hard coded)
* merged common code in predict_proba and predict in TransformersEstimator
parent 2f5d6169d3
commit 9c00e4272a

docs/conf.py
@@ -17,9 +17,9 @@

# -- Project information -----------------------------------------------------

project = 'FLAML'
copyright = '2020-2021, FLAML Team'
author = 'FLAML Team'
project = "FLAML"
copyright = "2020-2021, FLAML Team"
author = "FLAML Team"


# -- General configuration ---------------------------------------------------

@@ -28,23 +28,23 @@ author = 'FLAML Team'
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.doctest',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
    'sphinx_rtd_theme',
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx.ext.doctest",
    "sphinx.ext.coverage",
    "sphinx.ext.mathjax",
    "sphinx.ext.viewcode",
    "sphinx.ext.githubpages",
    "sphinx_rtd_theme",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------

@@ -52,9 +52,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
@@ -1,7 +1,7 @@
'''!
"""!
 * Copyright (c) 2020-2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License.
'''
"""

N_SPLITS = 5
RANDOM_SEED = 1
@@ -14,7 +14,14 @@ from typing import Dict, Union, List

# TODO: if your task is not specified in here, define your task as an all-capitalized word
SEQCLASSIFICATION = "seq-classification"
CLASSIFICATION = ("binary", "multi", "classification", SEQCLASSIFICATION)
MULTICHOICECLASSIFICATION = "multichoice-classification"
CLASSIFICATION = (
    "binary",
    "multi",
    "classification",
    SEQCLASSIFICATION,
    MULTICHOICECLASSIFICATION,
)
SEQREGRESSION = "seq-regression"
REGRESSION = ("regression", SEQREGRESSION)
TS_FORECAST = "ts_forecast"

@@ -26,6 +33,7 @@ NLG_TASKS = (SUMMARIZATION,)
NLU_TASKS = (
    SEQREGRESSION,
    SEQCLASSIFICATION,
    MULTICHOICECLASSIFICATION,
)
@@ -27,6 +27,7 @@ from .data import (
    SEQREGRESSION,
    SUMMARIZATION,
    NLG_TASKS,
    MULTICHOICECLASSIFICATION,
)

import pandas as pd

@@ -409,6 +410,7 @@ class TransformersEstimator(BaseEstimator):
        # from .nlp.huggingface.trainer import Seq2SeqTrainerForAuto as TrainerForAuto
        # else:
        from .nlp.huggingface.trainer import TrainerForAuto
        from .nlp.huggingface.data_collator import DataCollatorForAuto

        this_params = self.params
@@ -563,6 +565,12 @@ class TransformersEstimator(BaseEstimator):
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            data_collator=DataCollatorForAuto(
                tokenizer=tokenizer,
                pad_to_multiple_of=8 if training_args.fp16 else None,
            )
            if self._task == MULTICHOICECLASSIFICATION
            else None,
            compute_metrics=self._compute_metrics_by_dataset_name,
            callbacks=[EarlyStoppingCallbackForAuto],
        )
@@ -658,41 +666,15 @@ class TransformersEstimator(BaseEstimator):
        )
        return metric_dict

    def predict_proba(self, X_test):
        assert (
            self._task in CLASSIFICATION
        ), "predict_proba() only for classification tasks."

    def _init_model_for_predict(self, X_test):
        from datasets import Dataset
        from .nlp.huggingface.trainer import TrainerForAuto
        from transformers import TrainingArguments
        from .nlp.utils import load_model
        from transformers import AutoTokenizer
        from .nlp.huggingface.trainer import TrainerForAuto
        from .nlp.huggingface.data_collator import DataCollatorForPredict

        X_test, _ = self._preprocess(X_test, **self._kwargs)
        test_dataset = Dataset.from_pandas(X_test)

        best_model = load_model(
            checkpoint_path=self._checkpoint_path,
            task=self._task,
            num_labels=self._num_labels,
            per_model_config=self._per_model_config,
        )
        training_args = TrainingArguments(
            per_device_eval_batch_size=1,
            output_dir=self.custom_hpo_args.output_dir,
        )
        self._model = TrainerForAuto(model=best_model, args=training_args)
        predictions = self._model.predict(test_dataset)
        return predictions.predictions

    def predict(self, X_test):
        from datasets import Dataset
        from .nlp.utils import load_model
        from .nlp.huggingface.trainer import TrainerForAuto

        X_test, _ = self._preprocess(X=X_test, **self._kwargs)
        test_dataset = Dataset.from_pandas(X_test)

        best_model = load_model(
            checkpoint_path=self._checkpoint_path,
            task=self._task,

@@ -704,7 +686,32 @@ class TransformersEstimator(BaseEstimator):
            output_dir=self.custom_hpo_args.output_dir,
            **self._training_args_config,
        )
        self._model = TrainerForAuto(model=best_model, args=training_args)
        tokenizer = AutoTokenizer.from_pretrained(
            self.custom_hpo_args.model_path, use_fast=True
        )
        self._model = TrainerForAuto(
            model=best_model,
            args=training_args,
            data_collator=DataCollatorForPredict(
                tokenizer=tokenizer,
                pad_to_multiple_of=8 if training_args.fp16 else None,
            )
            if self._task == MULTICHOICECLASSIFICATION
            else None,
        )
        return test_dataset, training_args

    def predict_proba(self, X_test):
        assert (
            self._task in CLASSIFICATION
        ), "predict_proba() only for classification tasks."

        test_dataset, _ = self._init_model_for_predict(X_test)
        predictions = self._model.predict(test_dataset)
        return predictions.predictions

    def predict(self, X_test):
        test_dataset, training_args = self._init_model_for_predict(X_test)
        if self._task not in NLG_TASKS:
            predictions = self._model.predict(test_dataset)
        else:

@@ -728,6 +735,8 @@ class TransformersEstimator(BaseEstimator):
                predictions, skip_special_tokens=True
            )
            return decoded_preds
        elif self._task == MULTICHOICECLASSIFICATION:
            return np.argmax(predictions.predictions, axis=1)

    def config2params(self, config: dict) -> dict:
        params = config.copy()
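For intuition on the multiple-choice branch added to predict() above: the model produces one score per candidate ending, so the predicted class is the argmax over the choice axis. A tiny standalone sketch with toy logits (illustration only, not FLAML code):

import numpy as np

# Toy scores for 2 examples x 4 candidate endings, shaped like the
# predictions.predictions array returned by a multiple-choice head.
logits = np.array([
    [0.1, 2.3, -0.5, 0.0],
    [1.7, 0.2, 0.9, 2.2],
])
print(np.argmax(logits, axis=1))  # [1 3] -> index of the chosen ending per example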
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from transformers.data.data_collator import DataCollatorWithPadding


@dataclass
class DataCollatorForAuto(DataCollatorWithPadding):
    def __call__(self, features):
        from itertools import chain
        import torch
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)]
            for feature in features
        ]
        flattened_features = list(chain(*flattened_features))
        batch = super(DataCollatorForAuto, self).__call__(flattened_features)
        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Add back labels
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch


class DataCollatorForPredict(DataCollatorWithPadding):
    def __call__(self, features):
        from itertools import chain
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)]
            for feature in features
        ]
        flattened_features = list(chain(*flattened_features))
        batch = super(DataCollatorForPredict, self).__call__(flattened_features)
        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        return batch
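For intuition, a small standalone sketch of the flatten-pad-unflatten step these collators perform. It assumes toy token ids and two examples with four choices each; the checkpoint name is the one used in the test at the bottom of this diff:

from itertools import chain

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
pad = DataCollatorWithPadding(tokenizer=tokenizer)

# Two examples, each carrying four candidate endings -> four token sequences apiece.
features = [
    {"input_ids": [[101, 7, 102], [101, 8, 9, 102], [101, 10, 102], [101, 11, 102]]},
    {"input_ids": [[101, 12, 13, 102], [101, 15, 102], [101, 16, 102], [101, 17, 102]]},
]
batch_size, num_choices = len(features), len(features[0]["input_ids"])

# Flatten to batch_size * num_choices single sequences so the stock padding
# collator can pad them all to one common length ...
flat = list(
    chain(*[[{k: v[i] for k, v in f.items()} for i in range(num_choices)] for f in features])
)
batch = pad(flat)
# ... then restore the (batch_size, num_choices, seq_len) shape the model expects.
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
print(batch["input_ids"].shape)  # torch.Size([2, 4, 4])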
@@ -1,7 +1,15 @@
import argparse
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, Any
from ..data import SUMMARIZATION, SEQREGRESSION, SEQCLASSIFICATION, NLG_TASKS

from ..data import (
    SUMMARIZATION,
    SEQREGRESSION,
    SEQCLASSIFICATION,
    NLG_TASKS,
    MULTICHOICECLASSIFICATION,
)


def load_default_huggingface_metric_for_task(task):

@@ -11,6 +19,8 @@ def load_default_huggingface_metric_for_task(task):
        return "rmse", "max"
    elif task == SUMMARIZATION:
        return "rouge", "max"
    elif task == MULTICHOICECLASSIFICATION:
        return "accuracy"
    # TODO: elif task == your task, return the default metric name for your task,
    # e.g., if task == MULTIPLECHOICE, return "accuracy"
    # notice this metric name has to be in ['accuracy', 'bertscore', 'bleu', 'bleurt',

@@ -32,6 +42,8 @@ def tokenize_text(X, Y=None, task=None, custom_hpo_args=None):
        return X_tokenized, None
    elif task in NLG_TASKS:
        return tokenize_seq2seq(X, Y, task=task, custom_hpo_args=custom_hpo_args)
    elif task == MULTICHOICECLASSIFICATION:
        return tokenize_text_multiplechoice(X, custom_hpo_args)


def tokenize_seq2seq(X, Y, task=None, custom_hpo_args=None):

@@ -140,6 +152,59 @@ def tokenize_row(
    return [tokenized_example[x] for x in tokenized_column_names]


def tokenize_text_multiplechoice(X, custom_hpo_args):
    from transformers import AutoTokenizer
    import pandas

    global tokenized_column_names

    this_tokenizer = AutoTokenizer.from_pretrained(
        custom_hpo_args.model_path,  # 'roberta-base'
        cache_dir=None,
        use_fast=True,
        revision="main",
        use_auth_token=None,
    )
    t = X[["sent1", "sent2", "ending0", "ending1", "ending2", "ending3"]]
    d = t.apply(
        lambda x: tokenize_swag(x, this_tokenizer, custom_hpo_args),
        axis=1,
        result_type="expand",
    )

    X_tokenized = pandas.DataFrame(columns=tokenized_column_names)
    X_tokenized[tokenized_column_names] = d
    output = X_tokenized.join(X)
    return output, None


def tokenize_swag(this_row, this_tokenizer, custom_hpo_args):
    global tokenized_column_names

    first_sentences = [[this_row["sent1"]] * 4]
    # get each 1st sentence, multiply to 4 sentences
    question_headers = this_row["sent2"]
    # sent2 are the noun part of 2nd line
    second_sentences = [
        question_headers + " " + this_row[key]
        for key in ["ending0", "ending1", "ending2", "ending3"]
    ]
    # now the 2nd-sentences are formed by combing the noun part and 4 ending parts

    # Flatten out
    # From 2 dimension to 1 dimension array
    first_sentences = list(chain(*first_sentences))

    tokenized_example = this_tokenizer(
        *tuple([first_sentences, second_sentences]),
        truncation=True,
        max_length=custom_hpo_args.max_seq_length,
        padding=False,
    )
    tokenized_column_names = sorted(tokenized_example.keys())
    return [tokenized_example[x] for x in tokenized_column_names]


def separate_config(config, task):
    if task in NLG_TASKS:
        from transformers import Seq2SeqTrainingArguments, TrainingArguments

@@ -248,15 +313,24 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
    def get_this_model(task):
        from transformers import AutoModelForSequenceClassification
        from transformers import AutoModelForSeq2SeqLM
        from transformers import AutoModelForMultipleChoice

        if task in (SEQCLASSIFICATION, SEQREGRESSION):
            return AutoModelForSequenceClassification.from_pretrained(
                checkpoint_path, config=model_config
            )
        # TODO: elif task == your task, fill in the line in your transformers example
        # that loads the model, e.g., if task == MULTIPLE CHOICE, according to
        # https://github.com/huggingface/transformers/blob/master/examples/pytorch/multiple-choice/run_swag.py#L298
        # you can return AutoModelForMultipleChoice.from_pretrained(checkpoint_path, config=model_config)
        elif task in NLG_TASKS:
            return AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint_path, config=model_config
            )
        elif task == MULTICHOICECLASSIFICATION:
            return AutoModelForMultipleChoice.from_pretrained(
                checkpoint_path, config=model_config
            )

    def is_pretrained_model_in_classification_head_list(model_type):
        return model_type in MODEL_CLASSIFICATION_HEAD_MAPPING
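For a concrete picture of the pairing that tokenize_swag performs, here is a small standalone sketch; the row values are taken from the SWAG-style test data at the bottom of this diff, and the tokenizer checkpoint is the one the test uses:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator", use_fast=True)

row = {
    "sent1": "A woman is seen running down a long track and jumping into a pit.",
    "sent2": "The camera",
    "ending0": "captures her as well as lifting weights down in place.",
    "ending1": "pans up to show another woman running down the track.",
    "ending2": "follows her movements as the group members follow her instructions.",
    "ending3": "follows her spinning her body around and ends by walking down a lane.",
}

# The context sentence is repeated once per choice; sent2 plus each ending forms
# the second segment, so every row yields four (context, candidate) pairs.
first_sentences = [row["sent1"]] * 4
second_sentences = [row["sent2"] + " " + row[k] for k in ("ending0", "ending1", "ending2", "ending3")]

enc = tokenizer(first_sentences, second_sentences, truncation=True, padding=False)
print(sorted(enc.keys()))     # e.g. ['attention_mask', 'input_ids', 'token_type_ids']
print(len(enc["input_ids"]))  # 4 token-id sequences, one per candidate ending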
@@ -1,2 +1,6 @@
from .trial_scheduler import TrialScheduler
from .online_scheduler import OnlineScheduler, OnlineSuccessiveDoublingScheduler, ChaChaScheduler
from .online_scheduler import (
    OnlineScheduler,
    OnlineSuccessiveDoublingScheduler,
    ChaChaScheduler,
)
setup.py

@@ -3,7 +3,7 @@ import os

here = os.path.abspath(os.path.dirname(__file__))

with open("README.md", "r") as fh:
with open("README.md", "r", encoding="UTF-8") as fh:
    long_description = fh.read()
@@ -0,0 +1,174 @@
import os
import pytest


@pytest.mark.skipif(os.name == "darwin", reason="do not run on mac os")
def test_mcc():
    from flaml import AutoML

    import pandas as pd

    train_data = {
        'video-id': ['anetv_fruimvo90vA', 'anetv_fruimvo90vA', 'anetv_fruimvo90vA',
                     'anetv_MldEr60j33M', 'lsmdc0049_Hannah_and_her_sisters-69438'],
        'fold-ind': ['10030', '10030', '10030', '5488', '17405'],
        'startphrase': ['A woman is seen running down a long track and jumping into a pit. The camera',
                        'A woman is seen running down a long track and jumping into a pit. The camera',
                        'A woman is seen running down a long track and jumping into a pit. The camera',
                        'A man in a white shirt bends over and picks up a large weight. He',
                        'Someone furiously shakes someone away. He'],
        'sent1': ['A woman is seen running down a long track and jumping into a pit.',
                  'A woman is seen running down a long track and jumping into a pit.',
                  'A woman is seen running down a long track and jumping into a pit.',
                  'A man in a white shirt bends over and picks up a large weight.',
                  'Someone furiously shakes someone away.'],
        'sent2': ['The camera', 'The camera', 'The camera', 'He', 'He'],
        'gold-source': ['gen', 'gen', 'gold', 'gen', 'gold'],
        'ending0': ['captures her as well as lifting weights down in place.',
                    'follows her spinning her body around and ends by walking down a lane.',
                    'watches her as she walks away and sticks her tongue out to another person.',
                    'lifts the weights over his head.',
                    'runs to a woman standing waiting.'],
        'ending1': ['pans up to show another woman running down the track.',
                    'pans around the two.',
                    'captures her as well as lifting weights down in place.',
                    'also lifts it onto his chest before hanging it back out again.',
                    'tackles him into the passenger seat.'],
        'ending2': ['follows her movements as the group members follow her instructions.',
                    'captures her as well as lifting weights down in place.',
                    'follows her spinning her body around and ends by walking down a lane.',
                    'spins around and lifts a barbell onto the floor.',
                    'pounds his fist against a cupboard.'],
        'ending3': ['follows her spinning her body around and ends by walking down a lane.',
                    'follows her movements as the group members follow her instructions.',
                    'pans around the two.',
                    'bends down and lifts the weight over his head.',
                    'offers someone the cup on his elbow and strides out.'],
        'label': [1, 3, 0, 0, 2]}
    dev_data = {
        'video-id': ['lsmdc3001_21_JUMP_STREET-422',
                     'lsmdc0001_American_Beauty-45991',
                     'lsmdc0001_American_Beauty-45991',
                     'lsmdc0001_American_Beauty-45991'],
        'fold-ind': ['11783', '10977', '10970', '10968'],
        'startphrase': ['Firing wildly he shoots holes through the tanker. He',
                        'He puts his spatula down. The Mercedes',
                        'He stands and looks around, his eyes finally landing on: The digicam and a stack of cassettes on a shelf. Someone',
                        "He starts going through someone's bureau. He opens the drawer in which we know someone keeps his marijuana, but he"],
        'sent1': ['Firing wildly he shoots holes through the tanker.',
                  'He puts his spatula down.',
                  'He stands and looks around, his eyes finally landing on: The digicam and a stack of cassettes on a shelf.',
                  "He starts going through someone's bureau."],
        'sent2': ['He', 'The Mercedes', 'Someone', 'He opens the drawer in which we know someone keeps his marijuana, but he'],
        'gold-source': ['gold', 'gold', 'gold', 'gold'],
        'ending0': ['overtakes the rig and falls off his bike.',
                    'fly open and drinks.',
                    "looks at someone's papers.",
                    'stops one down and rubs a piece of the gift out.'],
        'ending1': ['squeezes relentlessly on the peanut jelly as well.',
                    'walks off followed driveway again.',
                    'feels around it and falls in the seat once more.',
                    'cuts the mangled parts.'],
        'ending2': ['scrambles behind himself and comes in other directions.',
                    'slots them into a separate green.',
                    'sprints back from the wreck and drops onto his back.',
                    'hides it under his hat to watch.'],
        'ending3': ['sweeps a explodes and knocks someone off.',
                    'pulls around to the drive - thru window.',
                    'sits at the kitchen table, staring off into space.',
                    "does n't discover its false bottom."],
        'label': [0, 3, 3, 3]}
    test_data = {
        'video-id': ['lsmdc0001_American_Beauty-45991',
                     'lsmdc0001_American_Beauty-45991',
                     'lsmdc0001_American_Beauty-45991',
                     'lsmdc0001_American_Beauty-45991'],
        'fold-ind': ['10980', '10976', '10978', '10969'],
        'startphrase': ['Someone leans out of the drive - thru window, grinning at her, holding bags filled with fast food. The Counter Girl',
                        'Someone looks up suddenly when he hears. He',
                        'Someone drives; someone sits beside her. They',
                        "He opens the drawer in which we know someone keeps his marijuana, but he does n't discover its false bottom. He stands and looks around, his eyes"],
        'sent1': ['Someone leans out of the drive - thru window, grinning at her, holding bags filled with fast food.',
                  'Someone looks up suddenly when he hears.',
                  'Someone drives; someone sits beside her.',
                  "He opens the drawer in which we know someone keeps his marijuana, but he does n't discover its false bottom."],
        'sent2': ['The Counter Girl', 'He', 'They', 'He stands and looks around, his eyes'],
        'gold-source': ['gold', 'gold', 'gold', 'gold'],
        'ending0': ['stands next to him, staring blankly.',
                    'puts his spatula down.',
                    "rise someone's feet up.",
                    'moving to the side, the houses rapidly stained.'],
        'ending1': ['with auditorium, filmed, singers the club.',
                    'bumps into a revolver and drops surreptitiously into his weapon.',
                    'lift her and they are alarmed.',
                    'focused as the sight of someone making his way down a trail.'],
        'ending2': ['attempts to block her ransacked.',
                    'talks using the phone and walks away for a few seconds.',
                    'are too involved with each other to notice someone watching them from the drive - thru window.',
                    'finally landing on: the digicam and a stack of cassettes on a shelf.'],
        'ending3': ['is eating solid and stinky.',
                    'bundles the flaxen powder beneath the car.',
                    'sit at a table with a beer from a table.',
                    "deep and continuing, its bleed - length sideburns pressing on him."],
        'label': [0, 0, 2, 2]}

    train_dataset = pd.DataFrame(train_data)
    dev_dataset = pd.DataFrame(dev_data)
    test_dataset = pd.DataFrame(test_data)

    custom_sent_keys = [
        "sent1",
        "sent2",
        "ending0",
        "ending1",
        "ending2",
        "ending3",
        "gold-source",
        "video-id",
        "startphrase",
        "fold-ind",
    ]
    label_key = "label"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    X_val = dev_dataset[custom_sent_keys]
    y_val = dev_dataset[label_key]

    X_test = test_dataset[custom_sent_keys]
    X_true = test_dataset[label_key]
    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 2,
        "time_budget": 5,
        "task": "multichoice-classification",
        "metric": "accuracy",
        "log_file_name": "seqclass.log",
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "google/electra-small-discriminator",
        "output_dir": "test/data/output/",
        "ckpt_per_epoch": 5,
        "fp16": False,
    }

    automl.fit(
        X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
    )

    y_pred = automl.predict(X_test)
    proba = automl.predict_proba(X_test)
    print(str(len(automl.classes_)) + " classes")
    print(y_pred)
    print(X_true)
    print(proba)
    true_count = 0
    for i, v in X_true.items():
        if y_pred[i] == v:
            true_count += 1
    accuracy = round(true_count / len(y_pred), 5)
    print("Accuracy: " + str(accuracy))


if __name__ == "__main__":
    test_mcc()
@@ -1,19 +1,21 @@
from azureml.core import Workspace, Experiment, ScriptRunConfig

ws = Workspace.from_config()

compute_target = ws.compute_targets['V100-4']
compute_target = ws.compute_targets["V100-4"]
# compute_target = ws.compute_targets['K80']
command = [
    "pip install torch transformers datasets flaml[blendsearch,ray] && ",
    "python test_electra.py"]
    "python test_electra.py",
]

config = ScriptRunConfig(
    source_directory='hf/',
    source_directory="hf/",
    command=command,
    compute_target=compute_target,
)

exp = Experiment(ws, 'test-electra')
exp = Experiment(ws, "test-electra")
run = exp.submit(config)
print(run.get_portal_url())  # link to ml.azure.com
run.wait_for_completion(show_output=True)