!3738 fix bug of cast dtype when using mix_presion in pynative mode

Merge pull request !3738 from jinyaohui/master
This commit is contained in:
mindspore-ci-bot 2020-08-02 23:29:51 +08:00 committed by Gitee
commit 2883f9366d
2 changed files with 18 additions and 7 deletions

View File

@ -17,6 +17,7 @@ import numpy as np
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
@ -115,6 +116,7 @@ def bool_or(x, y):
"""Implement `bool_or`."""
return x or y
def vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
@ -143,10 +145,12 @@ def list_len(x):
"""Implement `list_len`."""
return len(x)
def Depend(value, expr):
"""Implement `Depend`."""
return value
# only used in PyNative mode
def make_ref(key, value, ref):
return value
@ -177,8 +181,12 @@ def stop_gradient(x):
hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`."""
def cast_inner(data):
return F.cast(data, dst_type)
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
return F.cast(data, dst_type)
return data
return hyper_map(cast_inner, x)

View File

@ -745,13 +745,16 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
return err_ret;
}
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
if (op_exec_info->op_name != prim::kPrimMixedPrecisionCast->name()) {
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
}
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
MS_LOG(DEBUG) << "RunOp end";
}
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
MS_LOG(DEBUG) << "RunOp end";
return result;
}