fix mix precesion operator issue

This commit is contained in:
kingfo 2020-07-23 20:27:01 +08:00
parent 0a2980ca74
commit 73ea9b7855
7 changed files with 40 additions and 13 deletions

View File

@ -14,6 +14,7 @@
# ============================================================================
"""builtin_operations"""
import numpy as np
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
@ -171,3 +172,12 @@ def tuple_to_array(x):
def stop_gradient(x):
"""Implement `stop_gradient`."""
return x
def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`."""
if isinstance(x, tuple):
res = list()
for item in x:
res.append(F.cast(item, dst_type))
return tuple(res)
return F.cast(x, dst_type)

View File

@ -61,7 +61,7 @@ struct OpExecInfo {
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
const std::set<std::string> ignore_infer_prim = {"make_ref"};
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
} // namespace pynative
} // namespace mindspore

View File

@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
namespace mindspore {
namespace pynative {
@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) {
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
resource_ = cell_resource_map_[cell_id];
}
MS_LOG(DEBUG) << "Newgraph already compiled";
return;
}
@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
if (top_g_ == nullptr) {
top_g_ = curr_g_ = g;
resource_ = std::make_shared<pipeline::Resource>();
cell_resource_map_[cell_id] = resource_;
df_builder_ = std::make_shared<FuncGraph>();
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
Pushp();
@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
MS_LOG(DEBUG) << "Clear res";
(void)graph_map_.erase(flag);
(void)cell_graph_map_.erase(flag);
(void)cell_resource_map_.erase(flag);
Clean();
// Maybe exit in the pynative runing op, so need reset pynative flag.
auto ms_context = MsContext::GetInstance();
@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
MS_LOG(DEBUG) << "Clear";
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;
graph_info_map_.clear();
op_id_map_.clear();
@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
Clear();
grad_flag_ = false;
op_forward_map_.clear();
df_builder_ = nullptr;
ad::CleanRes();
pipeline::ReclaimOptimizer();
}

View File

@ -119,6 +119,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool grad_flag_;
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::unordered_map<std::string, ValuePtr> op_forward_map_;
std::unordered_map<std::string, size_t> op_id_map_;

View File

@ -240,12 +240,13 @@ class Cell:
else:
_pynative_exec.set_grad_flag(False)
cast_inputs = list()
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float16))
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float32))
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float16))
if self._mindspore_flags.get('fp32'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float32))
if cast_inputs:
cast_inputs = tuple(cast_inputs)
else:
@ -496,10 +497,11 @@ class Cell:
Args:
param (Parameter): The parameter to cast.
"""
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
return cast(param, mstype.float16)
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
return cast(param, mstype.float32)
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
return cast(param, mstype.float16)
if self._mindspore_flags.get('fp32'):
return cast(param, mstype.float32)
return param
def insert_child_to_cell(self, child_name, child):

View File

@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell):
def __init__(self, network, optimizer, scale_update_cell=None):
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer

View File

@ -20,6 +20,7 @@ import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
@ -638,3 +639,9 @@ def test_large_for_loop_with_continue_break():
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)
def test_mixed_precision_cast():
x = Tensor(np.ones([2, 3], dtype=np.float32))
z = F.mixed_precision_cast(mstype.float16, x)
assert z.dtype == mstype.float16