!29382 Handle monad inputs in signatures.

Merge pull request !29382 from huangbingjian/update_signature
This commit is contained in:
i-robot 2022-01-26 03:04:04 +00:00 committed by Gitee
commit 80bb043360
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 15 additions and 7 deletions

View File

@ -247,15 +247,25 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
}
}
void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list,
const std::string &func_name) {
void CheckSigSize(const ValuePtr &function, const size_t &sig_size, const bool &has_var,
const AbstractBasePtrList &args_spec_list, const std::string &func_name) {
if (sig_size > 0) {
if (has_var) {
if (sig_size - 1 > args_spec_list.size()) {
MS_LOG(EXCEPTION) << "Function " << func_name
<< "'s input length less than PositionalKeyword Signature length.";
}
} else if (args_spec_list.size() > sig_size) {
return;
}
// Consider the case where there are monads in primitive's args_spec_list.
size_t args_size = args_spec_list.size();
if (function->isa<Primitive>()) {
auto prim = function->cast<PrimitivePyPtr>();
if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
args_size -= GetAbstractMonadNum(args_spec_list);
}
}
if (args_size > sig_size) {
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
}
}
@ -279,7 +289,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
auto &signature = GetSignature(function);
std::size_t sig_size = signature.size();
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
CheckSigSize(sig_size, has_var, args_spec_list, func_name);
CheckSigSize(function, sig_size, has_var, args_spec_list, func_name);
std::vector<AnfNodePtr> op_inputs;
std::set<size_t> write_indices;
std::vector<TypePtr> input_types;

View File

@ -15,7 +15,6 @@
"""Other operators."""
import functools
import mindspore.common._monad as monad
from mindspore import log as logger
from .. import signature as sig
from ..._checkparam import Validator as validator, Rel
@ -59,8 +58,7 @@ class Assign(Primitive):
"""
__mindspore_signature__ = (
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('value', dtype=sig.sig_dtype.T),
sig.make_sig('u', default=monad.U, dtype=sig.sig_dtype.T1)
sig.make_sig('value', dtype=sig.sig_dtype.T)
)
@prim_attr_register