!18488 Remove Parameter Check in PyNative mode

Merge pull request !18488 from caifubi/master-pynative-no-parameter
This commit is contained in:
i-robot 2021-07-05 12:05:46 +00:00 committed by Gitee
commit ea440db85e
2 changed files with 2 additions and 10 deletions

View File

@ -1091,14 +1091,9 @@ void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map
if (!signature.empty()) {
sig = signature[i].rw;
}
bool is_parameter = false;
TypeId arg_type_id = kTypeUnknown;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto arg = py::cast<tensor::MetaTensorPtr>(obj);
if (arg->is_parameter()) {
is_parameter = true;
MS_LOG(DEBUG) << "Parameter is read " << i;
}
arg_type_id = arg->data_type();
}
// implicit cast
@ -1107,9 +1102,6 @@ void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
}
if (sig == SignatureEnumRW::kRWWrite) {
if (!is_parameter) {
prim::RaiseExceptionForCheckParameter(prim->name(), i, "not");
}
if (arg_type_id != kTypeUnknown) {
if (!is_same_type) {
prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),

View File

@ -283,6 +283,6 @@ def test_assign_check_in_sig():
net = AssignCheck()
x = Tensor(2, ms.int8)
y = Tensor(3, ms.uint8)
with pytest.raises(TypeError) as e:
with pytest.raises(RuntimeError) as e:
net(x, y)
assert "Parameter" in e.value.args[0]
assert "can not cast automatically" in e.value.args[0]