!3738 fix bug of cast dtype when using mix_presion in pynative mode
Merge pull request !3738 from jinyaohui/master
This commit is contained in:
commit
2883f9366d
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
||||
|
||||
|
||||
|
@ -115,6 +116,7 @@ def bool_or(x, y):
|
|||
"""Implement `bool_or`."""
|
||||
return x or y
|
||||
|
||||
|
||||
def vm_compare(*args):
|
||||
"""Implement `vm_compare` for tensor."""
|
||||
obj_str = args[-1]
|
||||
|
@ -143,10 +145,12 @@ def list_len(x):
|
|||
"""Implement `list_len`."""
|
||||
return len(x)
|
||||
|
||||
|
||||
def Depend(value, expr):
|
||||
"""Implement `Depend`."""
|
||||
return value
|
||||
|
||||
|
||||
# only used in PyNative mode
|
||||
def make_ref(key, value, ref):
|
||||
return value
|
||||
|
@ -177,8 +181,12 @@ def stop_gradient(x):
|
|||
|
||||
hyper_map = C.HyperMap()
|
||||
|
||||
|
||||
def mixed_precision_cast(dst_type, x):
|
||||
"""Implement `mixed_precision_cast`."""
|
||||
def cast_inner(data):
|
||||
return F.cast(data, dst_type)
|
||||
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
|
||||
return F.cast(data, dst_type)
|
||||
return data
|
||||
|
||||
return hyper_map(cast_inner, x)
|
||||
|
|
|
@ -745,13 +745,16 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
|
|||
return err_ret;
|
||||
}
|
||||
|
||||
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
|
||||
if (cnode != nullptr) {
|
||||
cnode->set_abstract(op_exec_info->abstract);
|
||||
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
|
||||
if (op_exec_info->op_name != prim::kPrimMixedPrecisionCast->name()) {
|
||||
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
|
||||
if (cnode != nullptr) {
|
||||
cnode->set_abstract(op_exec_info->abstract);
|
||||
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
|
||||
}
|
||||
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
|
||||
MS_LOG(DEBUG) << "RunOp end";
|
||||
}
|
||||
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
|
||||
MS_LOG(DEBUG) << "RunOp end";
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue