diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc index 133fb4732c2..c980da5c66a 100644 --- a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -247,15 +247,25 @@ void DoAutoCast(const std::string &func_name, const std::vector &sign } } -void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list, - const std::string &func_name) { +void CheckSigSize(const ValuePtr &function, const size_t &sig_size, const bool &has_var, + const AbstractBasePtrList &args_spec_list, const std::string &func_name) { if (sig_size > 0) { if (has_var) { if (sig_size - 1 > args_spec_list.size()) { MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length less than PositionalKeyword Signature length."; } - } else if (args_spec_list.size() > sig_size) { + return; + } + // Consider the case where there are monads in primitive's args_spec_list. + size_t args_size = args_spec_list.size(); + if (function->isa()) { + auto prim = function->cast(); + if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) { + args_size -= GetAbstractMonadNum(args_spec_list); + } + } + if (args_size > sig_size) { MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; } } @@ -279,7 +289,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func auto &signature = GetSignature(function); std::size_t sig_size = signature.size(); auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); - CheckSigSize(sig_size, has_var, args_spec_list, func_name); + CheckSigSize(function, sig_size, has_var, args_spec_list, func_name); std::vector op_inputs; std::set write_indices; std::vector input_types; diff --git a/mindspore/python/mindspore/ops/operations/other_ops.py b/mindspore/python/mindspore/ops/operations/other_ops.py index 4503b0d23dc..00c77d91627 100644 --- a/mindspore/python/mindspore/ops/operations/other_ops.py +++ b/mindspore/python/mindspore/ops/operations/other_ops.py @@ -15,7 +15,6 @@ """Other operators.""" import functools -import mindspore.common._monad as monad from mindspore import log as logger from .. import signature as sig from ..._checkparam import Validator as validator, Rel @@ -59,8 +58,7 @@ class Assign(Primitive): """ __mindspore_signature__ = ( sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), - sig.make_sig('value', dtype=sig.sig_dtype.T), - sig.make_sig('u', default=monad.U, dtype=sig.sig_dtype.T1) + sig.make_sig('value', dtype=sig.sig_dtype.T) ) @prim_attr_register