diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py
index 6bd382c1b6..780b5fe367 100644
--- a/mindspore/_extends/builtin_operations.py
+++ b/mindspore/_extends/builtin_operations.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """builtin_operations"""
 import numpy as np
+from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
 
@@ -171,3 +172,12 @@ def tuple_to_array(x):
 def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x
+
+def mixed_precision_cast(dst_type, x):
+    """Implement `mixed_precision_cast`."""
+    if isinstance(x, tuple):
+        res = list()
+        for item in x:
+            res.append(F.cast(item, dst_type))
+        return tuple(res)
+    return F.cast(x, dst_type)
diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h
index efde3f2e58..2739b6036e 100644
--- a/mindspore/ccsrc/pipeline/pynative/base.h
+++ b/mindspore/ccsrc/pipeline/pynative/base.h
@@ -61,7 +61,7 @@ struct OpExecInfo {
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
 
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
 
-const std::set<std::string> ignore_infer_prim = {"make_ref"};
+const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 }  // namespace pynative
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index 224d10c214..e77cfb233d 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -57,7 +57,7 @@
 using mindspore::tensor::TensorPy;
 const char SINGLE_OP_GRAPH[] = "single_op_graph";
 // primitive unable to infer value for constant input in PyNative mode
-const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
+const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
 
 namespace mindspore {
 namespace pynative {
@@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetId(cell);
   if (cell_graph_map_.count(cell_id) != 0) {
+    if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
+      resource_ = cell_resource_map_[cell_id];
+    }
     MS_LOG(DEBUG) << "Newgraph already compiled";
     return;
   }
@@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
 
   if (top_g_ == nullptr) {
     top_g_ = curr_g_ = g;
+    resource_ = std::make_shared();
+    cell_resource_map_[cell_id] = resource_;
     df_builder_ = std::make_shared();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
     Pushp();
@@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
     MS_LOG(DEBUG) << "Clear res";
     (void)graph_map_.erase(flag);
    (void)cell_graph_map_.erase(flag);
+    (void)cell_resource_map_.erase(flag);
     Clean();
     // Maybe exit in the pynative runing op, so need reset pynative flag.
     auto ms_context = MsContext::GetInstance();
@@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
 
   MS_LOG(DEBUG) << "Clear";
   top_g_ = nullptr;
+  df_builder_ = nullptr;
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
@@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
   Clear();
   grad_flag_ = false;
   op_forward_map_.clear();
-  df_builder_ = nullptr;
   ad::CleanRes();
   pipeline::ReclaimOptimizer();
 }
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index 1a5cb7408b..4b940246ef 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -119,6 +119,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool grad_flag_;
   std::unordered_map graph_map_;
   std::unordered_map cell_graph_map_;
+  std::unordered_map cell_resource_map_;
   std::unordered_map graph_info_map_;
   std::unordered_map op_forward_map_;
   std::unordered_map op_id_map_;
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 723fa21300..5e7e2017f5 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -240,12 +240,13 @@ class Cell:
         else:
             _pynative_exec.set_grad_flag(False)
         cast_inputs = list()
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float16))
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float32))
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float16))
+            if self._mindspore_flags.get('fp32'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float32))
         if cast_inputs:
             cast_inputs = tuple(cast_inputs)
         else:
@@ -496,10 +497,11 @@ class Cell:
         Args:
             param (Parameter): The parameter to cast.
""" - if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'): - return cast(param, mstype.float16) - if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'): - return cast(param, mstype.float32) + if hasattr(self, "_mindspore_flags"): + if self._mindspore_flags.get('fp16'): + return cast(param, mstype.float16) + if self._mindspore_flags.get('fp32'): + return cast(param, mstype.float32) return param def insert_child_to_cell(self, child_name, child): diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index 31dfb9f9ed..9cda88b26b 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell): def __init__(self, network, optimizer, scale_update_cell=None): super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network + self.network.set_grad() self.network.add_flags(defer_inline=True) self.weights = optimizer.parameters self.optimizer = optimizer diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index 9de4a9c82f..13b0aa9ce3 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -20,6 +20,7 @@ import mindspore as ms from mindspore import Tensor from mindspore import context from mindspore import nn +from mindspore.common import dtype as mstype from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P @@ -638,3 +639,9 @@ def test_large_for_loop_with_continue_break(): t = Tensor(np.ones([2, 3], dtype=np.float32)) net = Net() net(t) + + +def test_mixed_precision_cast(): + x = Tensor(np.ones([2, 3], dtype=np.float32)) + z = F.mixed_precision_cast(mstype.float16, x) + assert z.dtype == mstype.float16