forked from OSSInnovation/mindspore
!3381 fix mix precision operator issue
Merge pull request !3381 from wangqiuliang/fix-mix-precsion-operator-issue
This commit is contained in:
commit
e09c61d6a6
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""builtin_operations"""
|
||||
import numpy as np
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
||||
|
||||
|
@ -171,3 +172,12 @@ def tuple_to_array(x):
|
|||
def stop_gradient(x):
    """
    Implement `stop_gradient`.

    The argument is handed back to the caller unchanged.
    """
    return x
|
||||
|
||||
def mixed_precision_cast(dst_type, x):
    """
    Implement `mixed_precision_cast`.

    Args:
        dst_type: Target type forwarded to `F.cast`.
        x: Either a single value or a tuple of values.

    Returns:
        A tuple holding the cast of each item when `x` is a tuple,
        otherwise the cast of `x` itself.
    """
    if not isinstance(x, tuple):
        return F.cast(x, dst_type)
    # Only the top level of the tuple is unpacked; each item goes through F.cast as-is.
    return tuple(F.cast(element, dst_type) for element in x)
|
||||
|
|
|
@ -61,7 +61,7 @@ struct OpExecInfo {
|
|||
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
|
||||
OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
|
||||
|
||||
const std::set<std::string> ignore_infer_prim = {"make_ref"};
|
||||
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;
|
|||
|
||||
const char SINGLE_OP_GRAPH[] = "single_op_graph";
|
||||
// Primitives that are unable to infer a value for constant inputs in PyNative mode
|
||||
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
|
||||
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
|
@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
|
|||
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
|
||||
auto cell_id = GetId(cell);
|
||||
if (cell_graph_map_.count(cell_id) != 0) {
|
||||
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
|
||||
resource_ = cell_resource_map_[cell_id];
|
||||
}
|
||||
MS_LOG(DEBUG) << "Newgraph already compiled";
|
||||
return;
|
||||
}
|
||||
|
@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|||
|
||||
if (top_g_ == nullptr) {
|
||||
top_g_ = curr_g_ = g;
|
||||
resource_ = std::make_shared<pipeline::Resource>();
|
||||
cell_resource_map_[cell_id] = resource_;
|
||||
df_builder_ = std::make_shared<FuncGraph>();
|
||||
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
|
||||
Pushp();
|
||||
|
@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|||
MS_LOG(DEBUG) << "Clear res";
|
||||
(void)graph_map_.erase(flag);
|
||||
(void)cell_graph_map_.erase(flag);
|
||||
(void)cell_resource_map_.erase(flag);
|
||||
Clean();
|
||||
// Execution may have exited while a PyNative op was running, so the pynative flag needs to be reset.
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|||
|
||||
MS_LOG(DEBUG) << "Clear";
|
||||
top_g_ = nullptr;
|
||||
df_builder_ = nullptr;
|
||||
curr_g_ = nullptr;
|
||||
graph_info_map_.clear();
|
||||
op_id_map_.clear();
|
||||
|
@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
|
|||
Clear();
|
||||
grad_flag_ = false;
|
||||
op_forward_map_.clear();
|
||||
df_builder_ = nullptr;
|
||||
ad::CleanRes();
|
||||
pipeline::ReclaimOptimizer();
|
||||
}
|
||||
|
|
|
@ -119,6 +119,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
bool grad_flag_;
|
||||
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
|
||||
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
|
||||
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
|
||||
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
|
||||
std::unordered_map<std::string, ValuePtr> op_forward_map_;
|
||||
std::unordered_map<std::string, size_t> op_id_map_;
|
||||
|
|
|
@ -240,12 +240,13 @@ class Cell:
|
|||
else:
|
||||
_pynative_exec.set_grad_flag(False)
|
||||
cast_inputs = list()
|
||||
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
|
||||
for item in inputs:
|
||||
cast_inputs.append(cast(item, mstype.float16))
|
||||
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
|
||||
for item in inputs:
|
||||
cast_inputs.append(cast(item, mstype.float32))
|
||||
if hasattr(self, "_mindspore_flags"):
|
||||
if self._mindspore_flags.get('fp16'):
|
||||
for item in inputs:
|
||||
cast_inputs.append(cast(item, mstype.float16))
|
||||
if self._mindspore_flags.get('fp32'):
|
||||
for item in inputs:
|
||||
cast_inputs.append(cast(item, mstype.float32))
|
||||
if cast_inputs:
|
||||
cast_inputs = tuple(cast_inputs)
|
||||
else:
|
||||
|
@ -496,10 +497,11 @@ class Cell:
|
|||
Args:
|
||||
param (Parameter): The parameter to cast.
|
||||
"""
|
||||
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
|
||||
return cast(param, mstype.float16)
|
||||
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
|
||||
return cast(param, mstype.float32)
|
||||
if hasattr(self, "_mindspore_flags"):
|
||||
if self._mindspore_flags.get('fp16'):
|
||||
return cast(param, mstype.float16)
|
||||
if self._mindspore_flags.get('fp32'):
|
||||
return cast(param, mstype.float32)
|
||||
return param
|
||||
|
||||
def insert_child_to_cell(self, child_name, child):
|
||||
|
|
|
@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
|
|
|
@ -20,6 +20,7 @@ import mindspore as ms
|
|||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -638,3 +639,9 @@ def test_large_for_loop_with_continue_break():
|
|||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
net(t)
|
||||
|
||||
|
||||
def test_mixed_precision_cast():
    """Check that `F.mixed_precision_cast` turns a float32 tensor into float16."""
    source = Tensor(np.ones([2, 3], dtype=np.float32))
    result = F.mixed_precision_cast(mstype.float16, source)
    assert result.dtype == mstype.float16
|
||||
|
|
Loading…
Reference in New Issue