tuning mindspore_test_framework

This commit is contained in:
panyifeng 2020-05-14 17:26:52 +08:00
parent 8003a89a7b
commit 5a7abf0740
36 changed files with 107 additions and 134 deletions

View File

@ -22,7 +22,7 @@ from mindspore.ops import operations as P
from ..mindspore_test import mindspore_test from ..mindspore_test import mindspore_test
from ..pipeline.gradient.compare_gradient import pipeline_for_compare_inputs_grad_with_npy_for_case_by_case_config from ..pipeline.gradient.compare_gradient import pipeline_for_compare_inputs_grad_with_npy_for_case_by_case_config
# from ...vm_impl import * from ...vm_impl import *
verification_set = [ verification_set = [
('MatMul', { ('MatMul', {

View File

@ -32,11 +32,11 @@ class CheckExceptionsEC(IExectorComponent):
'error_keywords': ['TensorAdd', 'shape'] 'error_keywords': ['TensorAdd', 'shape']
} }
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
f = function[keyword.block] f = self.function[keyword.block]
args = inputs[keyword.desc_inputs] args = self.inputs[keyword.desc_inputs]
e = function.get(keyword.exception, Exception) e = self.function.get(keyword.exception, Exception)
error_kws = function.get(keyword.error_keywords, None) error_kws = self.function.get(keyword.error_keywords, None)
try: try:
with pytest.raises(e) as exec_info: with pytest.raises(e) as exec_info:
f(*args) f(*args)

View File

@ -26,8 +26,8 @@ class CheckGradientForScalarFunctionEC(IExectorComponent):
Examples: Examples:
'block': scalar_function 'block': scalar_function
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
f, args, delta, max_error, input_selector, output_selector, sampling_times, _ = \ f, args, delta, max_error, input_selector, output_selector, sampling_times, _ = \
get_grad_checking_options(function, inputs) get_grad_checking_options(self.function, self.inputs)
check_gradient(f, *args, delta=delta, max_error=max_error, grad_checker_class=ScalarGradChecker, check_gradient(f, *args, delta=delta, max_error=max_error, grad_checker_class=ScalarGradChecker,
input_selector=input_selector, output_selector=output_selector, sampling_times=sampling_times) input_selector=input_selector, output_selector=output_selector, sampling_times=sampling_times)

View File

@ -35,9 +35,9 @@ class CheckGradientWrtInputsEC(IExectorComponent):
key_act=None, key_act=None,
initializer_range=0.02) initializer_range=0.02)
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
f, args, delta, max_error, input_selector, output_selector, \ f, args, delta, max_error, input_selector, output_selector, \
sampling_times, reduce_output = get_grad_checking_options(function, inputs) sampling_times, reduce_output = get_grad_checking_options(self.function, self.inputs)
check_gradient(f, *args, delta=delta, max_error=max_error, grad_checker_class=OperationGradChecker, check_gradient(f, *args, delta=delta, max_error=max_error, grad_checker_class=OperationGradChecker,
input_selector=input_selector, output_selector=output_selector, sampling_times=sampling_times, input_selector=input_selector, output_selector=output_selector, sampling_times=sampling_times,
reduce_output=reduce_output) reduce_output=reduce_output)

View File

@ -35,9 +35,9 @@ class CheckGradientWrtParamsEC(IExectorComponent):
key_act=None, key_act=None,
initializer_range=0.02) initializer_range=0.02)
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
f, args, delta, max_error, input_selector, output_selector, \ f, args, delta, max_error, input_selector, output_selector, \
sampling_times, reduce_output = get_grad_checking_options(function, inputs) sampling_times, reduce_output = get_grad_checking_options(self.function, self.inputs)
check_gradient(f, *args, delta=delta, max_error=max_error, grad_checker_class=NNGradChecker, check_gradient(f, *args, delta=delta, max_error=max_error, grad_checker_class=NNGradChecker,
input_selector=input_selector, output_selector=output_selector, sampling_times=sampling_times, input_selector=input_selector, output_selector=output_selector, sampling_times=sampling_times,
reduce_output=reduce_output) reduce_output=reduce_output)

View File

@ -26,8 +26,8 @@ class CheckJacobianForScalarFunctionEC(IExectorComponent):
Examples: Examples:
'block': scalar_function 'block': scalar_function
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
f, args, delta, max_error, input_selector, output_selector, _, _ = \ f, args, delta, max_error, input_selector, output_selector, _, _ = \
get_grad_checking_options(function, inputs) get_grad_checking_options(self.function, self.inputs)
check_jacobian(f, *args, delta=delta, max_error=max_error, grad_checker_class=ScalarGradChecker, check_jacobian(f, *args, delta=delta, max_error=max_error, grad_checker_class=ScalarGradChecker,
input_selector=input_selector, output_selector=output_selector) input_selector=input_selector, output_selector=output_selector)

View File

@ -35,8 +35,8 @@ class CheckJacobianWrtInputsEC(IExectorComponent):
key_act=None, key_act=None,
initializer_range=0.02) initializer_range=0.02)
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
f, args, delta, max_error, input_selector, output_selector, _, _ = \ f, args, delta, max_error, input_selector, output_selector, _, _ = \
get_grad_checking_options(function, inputs) get_grad_checking_options(self.function, self.inputs)
check_jacobian(f, *args, delta=delta, max_error=max_error, grad_checker_class=OperationGradChecker, check_jacobian(f, *args, delta=delta, max_error=max_error, grad_checker_class=OperationGradChecker,
input_selector=input_selector, output_selector=output_selector) input_selector=input_selector, output_selector=output_selector)

View File

@ -35,8 +35,8 @@ class CheckJacobianWrtParamsEC(IExectorComponent):
key_act=None, key_act=None,
initializer_range=0.02) initializer_range=0.02)
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
f, args, delta, max_error, input_selector, output_selector, _, _ = \ f, args, delta, max_error, input_selector, output_selector, _, _ = \
get_grad_checking_options(function, inputs) get_grad_checking_options(self.function, self.inputs)
check_jacobian(f, *args, delta=delta, max_error=max_error, grad_checker_class=NNGradChecker, check_jacobian(f, *args, delta=delta, max_error=max_error, grad_checker_class=NNGradChecker,
input_selector=input_selector, output_selector=output_selector) input_selector=input_selector, output_selector=output_selector)

View File

@ -32,13 +32,13 @@ class LossVerifierEC(IExectorComponent):
'loss_upper_bound': 0.03, 'loss_upper_bound': 0.03,
} }
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
model = function[keyword.block][keyword.model] model = self.function[keyword.block][keyword.model]
loss = function[keyword.block][keyword.loss] loss = self.function[keyword.block][keyword.loss]
opt = function[keyword.block][keyword.opt] opt = self.function[keyword.block][keyword.opt]
num_epochs = function[keyword.block][keyword.num_epochs] num_epochs = self.function[keyword.block][keyword.num_epochs]
loss_upper_bound = function[keyword.block][keyword.loss_upper_bound] loss_upper_bound = self.function[keyword.block][keyword.loss_upper_bound]
train_dataset = inputs[keyword.desc_inputs] train_dataset = self.inputs[keyword.desc_inputs]
model = Model(model, loss, opt) model = Model(model, loss, opt)
loss = model.train(num_epochs, train_dataset) loss = model.train(num_epochs, train_dataset)
assert loss.asnumpy().mean() <= loss_upper_bound assert loss.asnumpy().mean() <= loss_upper_bound

View File

@ -22,12 +22,12 @@ class IdentityEC(IExectorComponent):
""" """
Execute function/inputs. Execute function/inputs.
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
result_id = function[keyword.id] + '-' + inputs[keyword.id] result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id]
group = function[keyword.group] + '-' + inputs[keyword.group] group = self.function[keyword.group] + '-' + self.inputs[keyword.group]
return { return {
keyword.id: result_id, keyword.id: result_id,
keyword.group: group, keyword.group: group,
keyword.desc_inputs: inputs[keyword.desc_inputs], keyword.desc_inputs: self.inputs[keyword.desc_inputs],
keyword.result: function[keyword.block](*inputs[keyword.desc_inputs]) keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
} }

View File

@ -22,15 +22,15 @@ class IdentityBackwardEC(IExectorComponent):
""" """
Execute function/inputs, with all bprops attached, the bprop function created by BC should handle these bprops. Execute function/inputs, with all bprops attached, the bprop function created by BC should handle these bprops.
""" """
def run_function(self, function, inputs, verification_set): def __call__(self):
result_id = function[keyword.id] + '-' + inputs[keyword.id] result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id]
group = function[keyword.group] + '-' + inputs[keyword.group] group = self.function[keyword.group] + '-' + self.inputs[keyword.group]
i = [] i = []
i.extend(inputs[keyword.desc_inputs]) i.extend(self.inputs[keyword.desc_inputs])
i.extend(inputs[keyword.desc_bprop]) i.extend(self.inputs[keyword.desc_bprop])
return { return {
keyword.id: result_id, keyword.id: result_id,
keyword.group: group, keyword.group: group,
keyword.desc_inputs: i, keyword.desc_inputs: i,
keyword.result: function[keyword.block](*i) keyword.result: self.function[keyword.block](*i)
} }

View File

@ -22,6 +22,6 @@ class GroupCartesianProductERPC(IERPolicyComponent):
""" """
Combine expect/result by do cartesian product on group. Combine expect/result by do cartesian product on group.
""" """
def combine(self, expect, result, verification_set): def __call__(self):
ret = [(s1, s2) for s1 in expect for s2 in result if s1[keyword.group] == s2[keyword.group]] ret = [(s1, s2) for s1 in self.expect for s2 in self.result if s1[keyword.group] == s2[keyword.group]]
return ret return ret

View File

@ -22,6 +22,6 @@ class IdCartesianProductERPC(IERPolicyComponent):
""" """
Combine expect/result by do cartesian product on id. Combine expect/result by do cartesian product on id.
""" """
def combine(self, expect, result, verification_set): def __call__(self):
ret = [(s1, s2) for s1 in expect for s2 in result if s1[keyword.id] == s2[keyword.id]] ret = [(s1, s2) for s1 in self.expect for s2 in self.result if s1[keyword.id] == s2[keyword.id]]
return ret return ret

View File

@ -47,9 +47,9 @@ class MeFacadeFC(IFacadeComponent):
} }
}) })
""" """
def adapt(self, verification_set): def __call__(self):
ret = get_block_config() ret = get_block_config()
for config in verification_set: for config in self.verification_set:
tid = config[0] tid = config[0]
group = 'default' group = 'default'
m = config[1] m = config[1]

View File

@ -42,5 +42,5 @@ class CompileBlockBC(IBuilderComponent):
dtype=mstype.float32, dtype=mstype.float32,
compute_type=mstype.float32) compute_type=mstype.float32)
""" """
def build_sut(self, verification_set): def __call__(self):
return create_funcs(verification_set, gen_net, compile_block) return create_funcs(self.verification_set, gen_net, compile_block)

View File

@ -43,6 +43,6 @@ class CompileBackwardBlockWrtInputsBC(IBuilderComponent):
dtype=mstype.float32, dtype=mstype.float32,
compute_type=mstype.float32) compute_type=mstype.float32)
""" """
def build_sut(self, verification_set): def __call__(self):
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation('grad', get_all=True, sens_param=True)
return create_funcs(verification_set, gen_grad_net, compile_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, compile_block, grad_op)

View File

@ -43,6 +43,6 @@ class CompileBackwardBlockWrtParamsBC(IBuilderComponent):
dtype=mstype.float32, dtype=mstype.float32,
compute_type=mstype.float32) compute_type=mstype.float32)
""" """
def build_sut(self, verification_set): def __call__(self, verification_set):
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) grad_op = GradOperation('grad', get_by_list=True, sens_param=True)
return create_funcs(verification_set, gen_grad_net, compile_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, compile_block, grad_op)

View File

@ -25,5 +25,5 @@ class IdentityBC(IBuilderComponent):
Examples: Examples:
'function': Add 'function': Add
""" """
def build_sut(self, verification_set): def __call__(self):
return verification_set[keyword.function] return self.verification_set[keyword.function]

View File

@ -42,5 +42,5 @@ class RunBlockWithRandParamBC(IBuilderComponent):
dtype=mstype.float32, dtype=mstype.float32,
compute_type=mstype.float32) compute_type=mstype.float32)
""" """
def build_sut(self, verification_set): def __call__(self):
return create_funcs(verification_set, gen_net, run_block, default_rand_func=get_uniform_with_shape) return create_funcs(self.verification_set, gen_net, run_block, default_rand_func=get_uniform_with_shape)

View File

@ -20,6 +20,6 @@ from ...components.icomponent import IBuilderComponent
from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_uniform_with_shape from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_uniform_with_shape
class RunBackwardBlockWrtInputsWithRandParamBC(IBuilderComponent): class RunBackwardBlockWrtInputsWithRandParamBC(IBuilderComponent):
def build_sut(self, verification_set): def __call__(self):
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation('grad', get_all=True, sens_param=True)
return create_funcs(verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape)

View File

@ -20,6 +20,6 @@ from ...components.icomponent import IBuilderComponent
from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_uniform_with_shape from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_uniform_with_shape
class RunBackwardBlockWrtParamsWithRandParamBC(IBuilderComponent): class RunBackwardBlockWrtParamsWithRandParamBC(IBuilderComponent):
def build_sut(self, verification_set): def __call__(self):
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) grad_op = GradOperation('grad', get_by_list=True, sens_param=True)
return create_funcs(verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape)

View File

@ -42,5 +42,5 @@ class RunBlockBC(IBuilderComponent):
dtype=mstype.float32, dtype=mstype.float32,
compute_type=mstype.float32) compute_type=mstype.float32)
""" """
def build_sut(self, verification_set): def __call__(self):
return create_funcs(verification_set, gen_net, run_block) return create_funcs(self.verification_set, gen_net, run_block)

View File

@ -20,6 +20,6 @@ from ...components.icomponent import IBuilderComponent
from ...utils.block_util import run_block, gen_grad_net, create_funcs from ...utils.block_util import run_block, gen_grad_net, create_funcs
class RunBackwardBlockWrtInputsBC(IBuilderComponent): class RunBackwardBlockWrtInputsBC(IBuilderComponent):
def build_sut(self, verification_set): def __call__(self):
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation('grad', get_all=True, sens_param=True)
return create_funcs(verification_set, gen_grad_net, run_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op)

View File

@ -20,6 +20,6 @@ from ...components.icomponent import IBuilderComponent
from ...utils.block_util import run_block, gen_grad_net, create_funcs from ...utils.block_util import run_block, gen_grad_net, create_funcs
class RunBackwardBlockWrtParamsBC(IBuilderComponent): class RunBackwardBlockWrtParamsBC(IBuilderComponent):
def build_sut(self, verification_set): def __call__(self):
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) grad_op = GradOperation('grad', get_by_list=True, sens_param=True)
return create_funcs(verification_set, gen_grad_net, run_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op)

View File

@ -23,7 +23,6 @@ class GroupCartesianProductFIPC(IFIPolicyComponent):
""" """
Combine function/inputs by do cartesian product on group. Combine function/inputs by do cartesian product on group.
""" """
def combine(self, function, inputs, verification_set): def __call__(self):
# pylint: disable=unused-argument ret = [(s1, s2) for s1 in self.function for s2 in self.inputs if s1[keyword.group] == s2[keyword.group]]
ret = [(s1, s2) for s1 in function for s2 in inputs if s1[keyword.group] == s2[keyword.group]]
return ret return ret

View File

@ -22,7 +22,6 @@ class IdCartesianProductFIPC(IFIPolicyComponent):
""" """
Combine function/inputs by do cartesian product on id. Combine function/inputs by do cartesian product on id.
""" """
def combine(self, function, inputs, verification_set): def __call__(self):
# pylint: disable=unused-argument ret = [(s1, s2) for s1 in self.function for s2 in self.inputs if s1[keyword.id] == s2[keyword.id]]
ret = [(s1, s2) for s1 in function for s2 in inputs if s1[keyword.id] == s2[keyword.id]]
return ret return ret

View File

@ -20,28 +20,19 @@ class IComponent:
def __init__(self, verification_set): def __init__(self, verification_set):
self.verification_set = verification_set self.verification_set = verification_set
def run(self): def __call__(self):
raise NotImplementedError raise NotImplementedError
def get_result(self):
return self.result
class IDataComponent(IComponent): class IDataComponent(IComponent):
"""Create inputs for verification_set.""" """Create inputs for verification_set."""
def run(self): def __call__(self):
self.result = self.create_inputs(self.verification_set)
def create_inputs(self, verification_set):
raise NotImplementedError raise NotImplementedError
class IBuilderComponent(IComponent): class IBuilderComponent(IComponent):
"""Build system under test.""" """Build system under test."""
def run(self): def __call__(self):
self.result = self.build_sut(self.verification_set)
def build_sut(self, verification_set):
raise NotImplementedError raise NotImplementedError
@ -52,10 +43,7 @@ class IExectorComponent(IComponent):
self.function = function self.function = function
self.inputs = inputs self.inputs = inputs
def run(self): def __call__(self):
self.result = self.run_function(self.function, self.inputs, self.verification_set)
def run_function(self, function, inputs, verification_set):
raise NotImplementedError raise NotImplementedError
@ -66,10 +54,7 @@ class IVerifierComponent(IComponent):
self.expect = expect self.expect = expect
self.func_result = result self.func_result = result
def run(self): def __call__(self):
self.result = self.verify(self.expect, self.func_result, self.verification_set)
def verify(self, expect, func_result, verification_set):
raise NotImplementedError raise NotImplementedError
@ -80,10 +65,7 @@ class IFIPolicyComponent(IComponent):
self.function = function self.function = function
self.inputs = inputs self.inputs = inputs
def run(self): def __call__(self):
self.result = self.combine(self.function, self.inputs, self.verification_set)
def combine(self, function, inputs, verification_set):
raise NotImplementedError raise NotImplementedError
@ -94,17 +76,11 @@ class IERPolicyComponent(IComponent):
self.expect = expect self.expect = expect
self.result = result self.result = result
def run(self): def __call__(self):
self.result = self.combine(self.expect, self.result, self.verification_set)
def combine(self, expect, result, verification_set):
raise NotImplementedError raise NotImplementedError
class IFacadeComponent(IComponent): class IFacadeComponent(IComponent):
"""Adapt verification_set.""" """Adapt verification_set."""
def run(self): def __call__(self):
self.result = self.adapt(self.verification_set)
def adapt(self, verification_set):
raise NotImplementedError raise NotImplementedError

View File

@ -30,9 +30,9 @@ class GenerateDataSetForLRDC(IDataComponent):
'batch_size': 20, 'batch_size': 20,
} }
""" """
def create_inputs(self, verification_set): def __call__(self):
result = [] result = []
for config in verification_set[keyword.inputs]: for config in self.verification_set[keyword.inputs]:
desc_inputs = config[keyword.desc_inputs] desc_inputs = config[keyword.desc_inputs]
config[keyword.desc_inputs] = generate_dataset_for_linear_regression(desc_inputs[keyword.true_params][0], config[keyword.desc_inputs] = generate_dataset_for_linear_regression(desc_inputs[keyword.true_params][0],
desc_inputs[keyword.true_params][1], desc_inputs[keyword.true_params][1],

View File

@ -41,9 +41,9 @@ class GenerateFromShapeDC(IDataComponent):
([1, 16, 128, 64], np.float32, 6), # (inputs, dtype, scale) ([1, 16, 128, 64], np.float32, 6), # (inputs, dtype, scale)
] ]
""" """
def create_inputs(self, verification_set): def __call__(self):
result = [] result = []
for config in verification_set[keyword.inputs]: for config in self.verification_set[keyword.inputs]:
desc_inputs = config[keyword.desc_inputs] desc_inputs = config[keyword.desc_inputs]
add_fake_input = config.get(keyword.add_fake_input, False) add_fake_input = config.get(keyword.add_fake_input, False)
fake_input_type = config.get(keyword.fake_input_type, np.float32) fake_input_type = config.get(keyword.fake_input_type, np.float32)

View File

@ -26,5 +26,5 @@ class IdentityDC(IDataComponent):
np.array([[2, 2], [2, 2]]).astype(np.float32) np.array([[2, 2], [2, 2]]).astype(np.float32)
] ]
""" """
def create_inputs(self, verification_set): def __call__(self):
return verification_set['inputs'] return self.verification_set['inputs']

View File

@ -43,9 +43,9 @@ class LoadFromNpyDC(IDataComponent):
([2, 2], np.float32, 6) ([2, 2], np.float32, 6)
] ]
""" """
def create_inputs(self, verification_set): def __call__(self):
result = [] result = []
for config in verification_set[keyword.inputs]: for config in self.verification_set[keyword.inputs]:
config[keyword.desc_inputs] = load_data_from_npy_or_shape(config[keyword.desc_inputs]) config[keyword.desc_inputs] = load_data_from_npy_or_shape(config[keyword.desc_inputs])
config[keyword.desc_bprop] = load_data_from_npy_or_shape(config.get(keyword.desc_bprop, [])) config[keyword.desc_bprop] = load_data_from_npy_or_shape(config.get(keyword.desc_bprop, []))
result.append(config) result.append(config)

View File

@ -41,5 +41,5 @@ class CompareWithVC(IVerifierComponent):
'max_error': 1e-3 'max_error': 1e-3
} }
""" """
def verify(self, expect, func_result, verification_set): def __call__(self):
compare(expect, func_result, baseline=keyword.compare_with) compare(self.expect, self.func_result, baseline=keyword.compare_with)

View File

@ -35,5 +35,5 @@ class CompareGradientWithVC(IVerifierComponent):
'max_error': 1e-3 'max_error': 1e-3
} }
""" """
def verify(self, expect, func_result, verification_set): def __call__(self):
compare(expect, func_result, baseline=keyword.compare_gradient_with) compare(self.expect, self.func_result, baseline=keyword.compare_gradient_with)

View File

@ -37,10 +37,10 @@ class LoadFromNpyVC(IVerifierComponent):
([2, 2], np.float32, 6, 1e-3) # (shape, dtype, scale, max_error) ([2, 2], np.float32, 6, 1e-3) # (shape, dtype, scale, max_error)
] ]
""" """
def verify(self, expect, func_result, verification_set): def __call__(self):
dpaths = expect.get(keyword.desc_expect) dpaths = self.expect.get(keyword.desc_expect)
expects = load_data_from_npy_or_shape(dpaths, False) expects = load_data_from_npy_or_shape(dpaths, False)
results = func_result[keyword.result] results = self.func_result[keyword.result]
if results: if results:
results = to_numpy_list(results) results = to_numpy_list(results)
for i, e in enumerate(expects): for i, e in enumerate(expects):

View File

@ -33,9 +33,9 @@ class ShapeTypeVC(IVerifierComponent):
] ]
} }
""" """
def verify(self, expect, func_result, verification_set): def __call__(self):
results = to_numpy_list(func_result[keyword.result]) results = to_numpy_list(self.func_result[keyword.result])
expects = expect[keyword.desc_expect][keyword.shape_type] expects = self.expect[keyword.desc_expect][keyword.shape_type]
for i, e in enumerate(expects): for i, e in enumerate(expects):
if results[i].shape != e[keyword.shape] or results[i].dtype != e[keyword.type]: if results[i].shape != e[keyword.shape] or results[i].dtype != e[keyword.type]:
raise TypeError(f'Error: expect {expect}, but got {func_result}') raise TypeError(f'Error: expect {self.expect}, but got {self.func_result}')

View File

@ -61,14 +61,13 @@ def mindspore_test(verification_pipeline):
for component in facade_components: for component in facade_components:
fc = component(verification_set) fc = component(verification_set)
fc.run() verification_set = fc()
verification_set = fc.get_result()
inputs = [] inputs = []
for component in data_components: for component in data_components:
dc = component(verification_set) dc = component(verification_set)
dc.run() item = dc()
inputs.extend(dc.get_result()) inputs.extend(item)
if not inputs: if not inputs:
logging.warning("Inputs set is empty.") logging.warning("Inputs set is empty.")
@ -76,8 +75,8 @@ def mindspore_test(verification_pipeline):
functions = [] functions = []
for component in builder_components: for component in builder_components:
bc = component(verification_set) bc = component(verification_set)
bc.run() f = bc()
functions.extend(bc.get_result()) functions.extend(f)
if not functions: if not functions:
logging.warning("Function set is empty.") logging.warning("Function set is empty.")
@ -85,8 +84,8 @@ def mindspore_test(verification_pipeline):
fis = [] fis = []
for component in fi_policy_components: for component in fi_policy_components:
fipc = component(verification_set, functions, inputs) fipc = component(verification_set, functions, inputs)
fipc.run() result = fipc()
fis.extend(fipc.get_result()) fis.extend(result)
if not fis: if not fis:
logging.warning("Function inputs pair set is empty.") logging.warning("Function inputs pair set is empty.")
@ -97,8 +96,8 @@ def mindspore_test(verification_pipeline):
results = [] results = []
for component in executor_components: for component in executor_components:
ec = component(verification_set, sut, inputs) ec = component(verification_set, sut, inputs)
ec.run() result = ec()
results.append(ec.get_result()) results.append(result)
if not results: if not results:
logging.warning("Result set is empty.") logging.warning("Result set is empty.")
@ -106,8 +105,8 @@ def mindspore_test(verification_pipeline):
expect_actuals = [] expect_actuals = []
for component in er_policy_components: for component in er_policy_components:
erpc = component(verification_set, verification_set['expect'], results) erpc = component(verification_set, verification_set['expect'], results)
erpc.run() result = erpc()
expect_actuals.extend(erpc.get_result()) expect_actuals.extend(result)
if not expect_actuals: if not expect_actuals:
logging.warning("Expect Result pair set is empty.") logging.warning("Expect Result pair set is empty.")
@ -115,11 +114,11 @@ def mindspore_test(verification_pipeline):
for ea in expect_actuals: for ea in expect_actuals:
for component in verifier_components: for component in verifier_components:
vc = component(verification_set, *ea) vc = component(verification_set, *ea)
vc.run() vc()
def get_tc_name(f, inputs): def get_tc_name(f, inputs):
tc_id = f[keyword.id]+'-'+inputs[keyword.id] tc_id = f[keyword.id] + '-' + inputs[keyword.id]
group = f[keyword.group]+'-'+inputs[keyword.group] group = f[keyword.group] + '-' + inputs[keyword.group]
return 'Group_' + group + '-' + 'Id_' + tc_id return 'Group_' + group + '-' + 'Id_' + tc_id
if fis: if fis: