add parser of case which parameter in tuple in run_op function

This commit is contained in:
guohongzilong 2020-04-29 17:55:12 +08:00
parent 286c1d7767
commit b96df362f8
2 changed files with 15 additions and 0 deletions

View File

@ -329,6 +329,10 @@ def _run_op(obj, op_name, args):
if hasattr(arg, '__parameter__'):
op_inputs.append(arg.default_input)
op_mask[i] = 1
elif isinstance(arg, tuple):
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
args_ = tuple(convert(x) for x in arg)
op_inputs.append(args_)
else:
op_inputs.append(arg)
output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask))

View File

@ -16,6 +16,7 @@
import numpy as np
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
@ -24,6 +25,7 @@ from ...ut_filter import non_graph_engine
tensor_add = P.TensorAdd()
op_add = P.AddN()
scala_add = Primitive('scalar_add')
add = C.MultitypeFuncGraph('add')
@ -50,5 +52,14 @@ def test_multitype_tensor():
mainf(tensor1, tensor2)
@non_graph_engine
def test_multitype_tuple():
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
params1 = Parameter(tensor1, name="params1")
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
output = op_add((params1, tensor2))
assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
def test_multitype_scalar():
mainf(1, 2)