forked from mindspore-Ecosystem/mindspore
add parser of case which parameter in tuple in run_op function
This commit is contained in:
parent
286c1d7767
commit
b96df362f8
|
@ -329,6 +329,10 @@ def _run_op(obj, op_name, args):
|
|||
if hasattr(arg, '__parameter__'):
|
||||
op_inputs.append(arg.default_input)
|
||||
op_mask[i] = 1
|
||||
elif isinstance(arg, tuple):
|
||||
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
|
||||
args_ = tuple(convert(x) for x in arg)
|
||||
op_inputs.append(args_)
|
||||
else:
|
||||
op_inputs.append(arg)
|
||||
output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask))
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import numpy as np
|
||||
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -24,6 +25,7 @@ from ...ut_filter import non_graph_engine
|
|||
|
||||
|
||||
tensor_add = P.TensorAdd()
|
||||
op_add = P.AddN()
|
||||
scala_add = Primitive('scalar_add')
|
||||
add = C.MultitypeFuncGraph('add')
|
||||
|
||||
|
@ -50,5 +52,14 @@ def test_multitype_tensor():
|
|||
mainf(tensor1, tensor2)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_multitype_tuple():
|
||||
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||
params1 = Parameter(tensor1, name="params1")
|
||||
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||
output = op_add((params1, tensor2))
|
||||
assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
|
||||
|
||||
|
||||
def test_multitype_scalar():
|
||||
mainf(1, 2)
|
||||
|
|
Loading…
Reference in New Issue