forked from mindspore-Ecosystem/mindspore
modift note for loss scale
This commit is contained in:
parent
be009760cd
commit
d663129e4a
|
@ -36,7 +36,7 @@ class LossScaleManager:
|
|||
|
||||
class FixedLossScaleManager(LossScaleManager):
|
||||
"""
|
||||
Fixed loss-scale manager.
|
||||
Loss scale with a fixed value, inherits from LossScaleManager.
|
||||
|
||||
Args:
|
||||
loss_scale (float): Loss scale. Note that if `drop_overflow_update` is set to False, the value of `loss_scale`
|
||||
|
@ -107,7 +107,7 @@ class FixedLossScaleManager(LossScaleManager):
|
|||
|
||||
class DynamicLossScaleManager(LossScaleManager):
|
||||
"""
|
||||
Dynamic loss-scale manager.
|
||||
Loss scale that dynamically adjusts itself, inherits from LossScaleManager.
|
||||
|
||||
Args:
|
||||
init_loss_scale (float): Initialize loss scale. Default: 2**24.
|
||||
|
|
|
@ -56,9 +56,14 @@ def do_eval_onnx(dataset=None, num_class=15, assessment_method="accuracy"):
|
|||
raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
|
||||
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
onnx_file_name = args_opt.export_file_name + '.onnx'
|
||||
onnx_file_name = args_opt.export_file_name
|
||||
if not args_opt.export_file_name.endswith('.onnx'):
|
||||
onnx_file_name = onnx_file_name + '.onnx'
|
||||
if not os.path.isabs(onnx_file_name):
|
||||
onnx_file_name = os.getcwd() + '/' + onnx_file_name
|
||||
if not os.path.exists(onnx_file_name):
|
||||
raise ValueError("ONNX file not exists, please check onnx file has been saved and whether the "
|
||||
"export_file_name is correct.")
|
||||
sess = rt.InferenceSession(onnx_file_name)
|
||||
input_name_0 = sess.get_inputs()[0].name
|
||||
input_name_1 = sess.get_inputs()[1].name
|
||||
|
@ -88,8 +93,6 @@ def run_classifier_onnx():
|
|||
"""run classifier task for onnx model"""
|
||||
if args_opt.eval_data_file_path == "":
|
||||
raise ValueError("'eval_data_file_path' must be set when do onnx evaluation task")
|
||||
if args_opt.onnx_file_path == "":
|
||||
raise ValueError("'onnx_file_path' must be set when do onnx evaluation task")
|
||||
assessment_method = args_opt.assessment_method.lower()
|
||||
ds = create_classification_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method,
|
||||
|
|
Loading…
Reference in New Issue