modift note for loss scale

This commit is contained in:
liuyang_655 2021-08-28 05:14:30 -04:00
parent be009760cd
commit d663129e4a
2 changed files with 8 additions and 5 deletions

View File

@ -36,7 +36,7 @@ class LossScaleManager:
class FixedLossScaleManager(LossScaleManager):
"""
Fixed loss-scale manager.
Loss scale with a fixed value, inherits from LossScaleManager.
Args:
loss_scale (float): Loss scale. Note that if `drop_overflow_update` is set to False, the value of `loss_scale`
@ -107,7 +107,7 @@ class FixedLossScaleManager(LossScaleManager):
class DynamicLossScaleManager(LossScaleManager):
"""
Dynamic loss-scale manager.
Loss scale that dynamically adjusts itself, inherits from LossScaleManager.
Args:
init_loss_scale (float): Initialize loss scale. Default: 2**24.

View File

@ -56,9 +56,14 @@ def do_eval_onnx(dataset=None, num_class=15, assessment_method="accuracy"):
raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
onnx_file_name = args_opt.export_file_name + '.onnx'
onnx_file_name = args_opt.export_file_name
if not args_opt.export_file_name.endswith('.onnx'):
onnx_file_name = onnx_file_name + '.onnx'
if not os.path.isabs(onnx_file_name):
onnx_file_name = os.getcwd() + '/' + onnx_file_name
if not os.path.exists(onnx_file_name):
raise ValueError("ONNX file not exists, please check onnx file has been saved and whether the "
"export_file_name is correct.")
sess = rt.InferenceSession(onnx_file_name)
input_name_0 = sess.get_inputs()[0].name
input_name_1 = sess.get_inputs()[1].name
@ -88,8 +93,6 @@ def run_classifier_onnx():
"""run classifier task for onnx model"""
if args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when do onnx evaluation task")
if args_opt.onnx_file_path == "":
raise ValueError("'onnx_file_path' must be set when do onnx evaluation task")
assessment_method = args_opt.assessment_method.lower()
ds = create_classification_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
assessment_method=assessment_method,