!29382 Handle monad inputs in signatures.
Merge pull request !29382 from huangbingjian/update_signature
This commit is contained in:
commit
80bb043360
|
@ -247,15 +247,25 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|||
}
|
||||
}
|
||||
|
||||
void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list,
|
||||
const std::string &func_name) {
|
||||
void CheckSigSize(const ValuePtr &function, const size_t &sig_size, const bool &has_var,
|
||||
const AbstractBasePtrList &args_spec_list, const std::string &func_name) {
|
||||
if (sig_size > 0) {
|
||||
if (has_var) {
|
||||
if (sig_size - 1 > args_spec_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "Function " << func_name
|
||||
<< "'s input length less than PositionalKeyword Signature length.";
|
||||
}
|
||||
} else if (args_spec_list.size() > sig_size) {
|
||||
return;
|
||||
}
|
||||
// Consider the case where there are monads in primitive's args_spec_list.
|
||||
size_t args_size = args_spec_list.size();
|
||||
if (function->isa<Primitive>()) {
|
||||
auto prim = function->cast<PrimitivePyPtr>();
|
||||
if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
|
||||
args_size -= GetAbstractMonadNum(args_spec_list);
|
||||
}
|
||||
}
|
||||
if (args_size > sig_size) {
|
||||
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
|
||||
}
|
||||
}
|
||||
|
@ -279,7 +289,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
auto &signature = GetSignature(function);
|
||||
std::size_t sig_size = signature.size();
|
||||
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
|
||||
CheckSigSize(sig_size, has_var, args_spec_list, func_name);
|
||||
CheckSigSize(function, sig_size, has_var, args_spec_list, func_name);
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
std::set<size_t> write_indices;
|
||||
std::vector<TypePtr> input_types;
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
"""Other operators."""
|
||||
import functools
|
||||
import mindspore.common._monad as monad
|
||||
from mindspore import log as logger
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
|
@ -59,8 +58,7 @@ class Assign(Primitive):
|
|||
"""
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('value', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('u', default=monad.U, dtype=sig.sig_dtype.T1)
|
||||
sig.make_sig('value', dtype=sig.sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue