diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index cc508cfaf5..c6eb673624 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -257,41 +257,84 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
   return RunOp(args)[0];
 }
 
-void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) {
-  auto &out_args = op_exec_info->op_inputs;
-  auto signature = prim->signatures();
-  std::vector<SignatureEnumDType> dtypes;
-  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
-                       [](const Signature &sig) { return sig.dtype; });
-  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
-  if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
-    return;
+py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
+  auto tensor = py::cast<tensor::TensorPtr>(obj);
+  auto cast_type = tensor->cast_dtype();
+  py::object cast_output;
+  if (cast_type != nullptr) {
+    auto source_element = tensor->Dtype();
+    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
+      MS_LOG(DEBUG) << "cast to " << cast_type->ToString();
+      cast_output = DoAutoCast(obj, cast_type->type_id());
+      *is_cast = true;
+    }
   }
-  auto type_indexes = GetTypeIndex(dtypes);
-  auto dst_type = GetDstType(out_args, type_indexes);
+  return cast_output;
+}
 
-  for (size_t i = 0; i < dtypes.size(); ++i) {
-    if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
+py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
+  auto tuple_size = static_cast<int>(tuple.size());
+  py::tuple result(tuple_size);
+
+  for (int i = 0; i < tuple_size; i++) {
+    if (py::isinstance<tensor::Tensor>(tuple[i])) {
+      MS_LOG(DEBUG) << "call cast for item " << i;
+      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
+    } else if (py::isinstance<py::tuple>(tuple[i])) {
+      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
+    }
+  }
+  return result;
+}
+
+bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
+  auto signature = prim->signatures();
+  bool has_sig_dtype = false;
+  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
+                       [&has_sig_dtype](const Signature &sig) {
+                         auto dtype = sig.dtype;
+                         if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
+                           has_sig_dtype = true;
+                         }
+                         return dtype;
+                       });
+  return has_sig_dtype;
+}
+
+void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
+                     const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) {
+  const auto &signature = prim->signatures();
+  auto &out_args = op_exec_info->op_inputs;
+  bool has_dtype_sig = (dtypes.size() > 0);
+  for (size_t i = 0; i < out_args.size(); ++i) {
+    MS_LOG(DEBUG) << "check inputs " << i;
+    auto obj = out_args[i];
+    auto sig = SignatureEnumRW::kRWDefault;
+    if (signature.size() > 0) {
+      sig = signature[i].rw;
+    }
+    bool is_parameter = false;
+    TypeId arg_type_id = kTypeUnknown;
+    if (py::isinstance<tensor::MetaTensor>(obj)) {
+      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
+      if (arg->is_parameter()) {
+        is_parameter = true;
+        MS_LOG(DEBUG) << "parameter is read " << i;
+      }
+      arg_type_id = arg->data_type();
+    }
+
+    // No need to do implicit cast if there is no dtype signature.
+    if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
       continue;
     }
     auto it = dst_type.find(dtypes[i]);
     if (it == dst_type.end() || it->second == kTypeUnknown) {
       continue;
     }
-
-    auto obj = out_args[i];
-    auto sig = signature[i].rw;
-    bool is_parameter = false;
+    // implicit cast
     bool is_same_type = false;
-    TypeId arg_type_id = kTypeUnknown;
     bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
-    if (py::isinstance<tensor::MetaTensor>(obj)) {
-      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
-      if (arg->is_parameter()) {
-        is_parameter = true;
-      }
-      arg_type_id = arg->data_type();
-    }
     if (arg_type_id != 0) {
       is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
     }
@@ -318,7 +361,6 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
     }
     py::object cast_output = DoAutoCast(out_args[i], it->second);
     out_args[i] = cast_output;
-    ValuePtr input_value = PyAttrValue(cast_output);
   }
 }
 
@@ -347,7 +389,6 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
   op_exec_info->op_inputs = args[PY_INPUTS];
-  ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
   return op_exec_info;
 }
 
@@ -698,11 +739,53 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
   inputs.push_back(NewValueNode(prim));
 
   size_t size = op_exec_info->op_inputs.size();
-  auto const_input_index = prim->get_const_input_indexes();
-  bool have_const_input = !const_input_index.empty();
-  bool is_const_prim = prim->is_const_prim();
+  // ignore signature for cast op
+  bool is_cast_op = (op_exec_info->op_name == "Cast");
+  if (!is_cast_op) {
+    const auto &signature = prim->signatures();
+    for (size_t i = 0; i < size; i++) {
+      auto obj = op_exec_info->op_inputs[i];
+      auto sig = SignatureEnumRW::kRWDefault;
+      if (signature.size() > 0) {
+        sig = signature[i].rw;
+      }
+      MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
+                    << std::string(py::repr(obj));
+      // mix precision for non param
+      bool is_cast = false;
+      py::object cast_output;
+      if (py::isinstance<tensor::MetaTensor>(obj)) {
+        auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
+        if (meta_tensor && meta_tensor->is_parameter()) {
+          if (sig != SignatureEnumRW::kRWRead) {
+            continue;
+          }
+        }
+        // redundant cast call if the tensor is a const Tensor.
+        cast_output = DoParamMixPrecisionCast(&is_cast, obj);
+      } else if (py::isinstance<py::tuple>(obj)) {
+        // mix precision for tuple inputs
+        cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
+      }
+      if (is_cast) {
+        op_exec_info->op_inputs[i] = cast_output;
+      }
+    }
+    std::vector<SignatureEnumDType> dtypes;
+
+    bool has_dtype_sig = GetSignatureType(prim, &dtypes);
+    std::map<SignatureEnumDType, TypeId> dst_types;
+    if (has_dtype_sig) {
+      // fetch info for implicit cast
+      auto type_indexes = GetTypeIndex(dtypes);
+      dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
+    }
+    MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
+    DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
+  }
+  MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
   for (size_t i = 0; i < size; i++) {
-    auto obj = op_exec_info->op_inputs[i];
+    const auto &obj = op_exec_info->op_inputs[i];
     bool op_mask = false;
     if (py::isinstance<tensor::MetaTensor>(obj)) {
       auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
@@ -710,9 +793,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
         op_mask = meta_tensor->is_parameter();
       }
     }
-
     (*op_masks).push_back(op_mask);
-    MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
+    MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ "
                   << grad_flag_;
 
     AnfNodePtr node = nullptr;
@@ -727,6 +809,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
     if (node != nullptr && node->abstract() != nullptr) {
       abs = node->abstract();
     }
+
+    auto const_input_index = prim->get_const_input_indexes();
+    bool have_const_input = !const_input_index.empty();
+    bool is_const_prim = prim->is_const_prim();
     MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                   << prim->is_const_prim();
     bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
@@ -998,7 +1084,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
 }
 
 py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
-  MS_LOG(DEBUG) << "RunOp start" << args.size();
+  MS_LOG(DEBUG) << "RunOp start " << args.size();
   OpExecInfoPtr op_exec_info = nullptr;
   auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
   auto name = py::cast<std::string>(args[PY_NAME]);
diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
index 13ba6314b5..f9636aacb9 100644
--- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
@@ -455,6 +455,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                                >>> data.set_dtype(mindspore.int32)
                                mindspore.int32
                              )mydelimiter")
+                           .def("set_cast_dtype", &Tensor::set_cast_dtype)
                            .def("__str__", &Tensor::ToString)
                            .def("__repr__", &Tensor::ToStringRepr)
                            .def(py::pickle(
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 279a93dc8a..e099ed18d5 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -292,6 +292,7 @@ class _PynativeExecutor:
     def __init__(self):
         self._executor = PynativeExecutor_.get_instance()
 
+    # TODO(kpy): add a type arg
     def new_graph(self, obj, *args, **kwargs):
         self._executor.new_graph(obj, *args, *(kwargs.values()))
 
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 1213f6acb8..f5cae3fa67 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -219,16 +219,6 @@ class Parameter(MetaTensor):
             raise ValueError("The type of the name should be `str` or `None`.")
         self._param_info.name = name_
 
-    @property
-    def cast_type(self):
-        return self._cast_type
-
-    @cast_type.setter
-    def cast_type(self, dst_type):
-        if dst_type not in (mstype.float16, mstype.float32, None):
-            raise ValueError("The type of the name should be type of [float32, float16] or `None`.")
-        self._cast_type = dst_type
-
     @property
     def sliced(self):
         """Get slice status of the parameter."""
diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h
index 13472b3f0c..8689980268 100644
--- a/mindspore/core/ir/tensor.h
+++ b/mindspore/core/ir/tensor.h
@@ -268,6 +268,8 @@ class Tensor : public MetaTensor {
   std::vector padding_type() const { return padding_type_; }
 
   std::string id() const { return id_; }
+  TypePtr cast_dtype() { return cast_dtype_; }
+  void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; }
 
   void SetNeedWait(bool need_wait) {
     if (event_ != nullptr) {
@@ -310,6 +312,7 @@ class Tensor : public MetaTensor {
   mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
   DeviceSyncPtr device_sync_{nullptr};
   std::vector padding_type_;
+  TypePtr cast_dtype_{nullptr};
 };
 using TensorPtr = std::shared_ptr<Tensor>;
 using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 8d1bc884a8..76f7f13a5c 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -61,15 +61,17 @@ class Cell(Cell_):
     """
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
                    '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
-                   '_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode',
-                   '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced',
-                   'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
+                   '_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
+                   '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
+                   '_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
+                   '_auto_parallel_compile_and_run', 'cell_type']
 
     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
         self._params = OrderedDict()
         self._cells = OrderedDict()
         self._params_list = OrderedDict()
+        self._tensor_list = OrderedDict()
         self.training = False
         self.requires_grad = False
         self.pynative = False
@@ -228,6 +230,9 @@ class Cell(Cell_):
                 return cells[name]
         if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
            params_list = self.__dict__['_params_list']
+           tensor_list = self.__dict__['_tensor_list']
+           if name in tensor_list:
+               return self.cast_param(tensor_list[name])
            if name in params_list:
                para_list = params_list[name]
                cast_list = list()
@@ -339,6 +344,7 @@ class Cell(Cell_):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
         params_list = self.__dict__.get('_params_list')
+        tensor_list = self.__dict__.get('_tensor_list')
         if isinstance(value, Parameter):
             if params is None:
                 raise AttributeError("Can not assign params before Cell.__init__() call.")
@@ -383,6 +389,13 @@ class Cell(Cell_):
                 if value is not None:
                     raise TypeError("Expected type is cell, but got {}.".format(type(value)))
                 self._cells[name] = None
+            elif isinstance(value, Tensor):
+                if context.get_context("mode") == context.PYNATIVE_MODE:
+                    if name in self.__dict__:
+                        del self.__dict__[name]
+                    tensor_list[name] = value
+                else:
+                    object.__setattr__(self, name, value)
             else:
                 if isinstance(value, Primitive):
                     value.set_prim_instance_name(name)
@@ -570,11 +583,9 @@ class Cell(Cell_):
         """
         if hasattr(self, "_mindspore_flags"):
             if self._mindspore_flags.get('fp16'):
-                param.cast_type = mstype.float16
-            elif self._mindspore_flags.get('fp32'):
-                param.cast_type = mstype.float32
-            else:
-                param.cast_type = None
+                param.set_cast_dtype(mstype.float16)
+            if self._mindspore_flags.get('fp32'):
+                param.set_cast_dtype(mstype.float32)
         return param
 
     def insert_child_to_cell(self, child_name, child):
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 1b470b36fd..25e05749f6 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -17,7 +17,6 @@
 import inspect
 import copy
 from mindspore.common.api import _wrap_func
-from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore import context
 from .._c_expression import Primitive_, real_run_op, prim_type
 from . import signature as sig
@@ -496,16 +495,7 @@ def constexpr(fn=None, get_instance=True, name=None):
 @_wrap_func
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
-    cast = tensor_operator_registry.get("cast")
-    if op_name == "Cast" or obj.update_parameter:
-        cast_args = args
-    else:
-        cast_args = list(args)
-        for idx, arg in enumerate(args):
-            cast_type = getattr(arg, "cast_type", None)
-            if cast_type:
-                cast_args[idx] = cast(arg, cast_type)
-    output = real_run_op(obj, op_name, cast_args)
+    output = real_run_op(obj, op_name, args)
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:
diff --git a/tests/ut/python/model/test_mix_precision.py b/tests/ut/python/model/test_mix_precision.py
index d311f0b40b..722ce1c39c 100644
--- a/tests/ut/python/model/test_mix_precision.py
+++ b/tests/ut/python/model/test_mix_precision.py
@@ -74,7 +74,7 @@ def test_add_cast_flag():
     net.fc3.to_float(mstype.float32)
     net = train_step_with_loss_warp(net)
     net.set_train()
-    _executor.compile(net, predict, label)
+    net(predict, label)
 
 
 def test_add_cast_flag_tensor():
@@ -82,7 +82,7 @@ def test_add_cast_flag_tensor():
     net = NetForConcat()
     net.add_flags_recursive(fp16=True)
     net.set_train()
-    _executor.compile(net, x1)
+    net(x1)
 
 
 def test_on_momentum():
@@ -91,7 +91,7 @@ def test_on_momentum():
     net = LeNet5()
     net = train_step_with_loss_warp(net).to_float(mstype.float16)
     net.set_train()
-    _executor.compile(net, predict, label)
+    net(predict, label)
 
 
 def test_data_parallel_with_cast():
@@ -134,7 +134,6 @@ def test_nn_prelu():
 class NetForCast(nn.Cell):
     def __init__(self):
         super(NetForCast, self).__init__()
-        self.concat = P.Concat()
         self.x1 = Tensor(1.0, mstype.float32)
         self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
 
@@ -144,11 +143,10 @@ class NetForCast(nn.Cell):
 
 
 def test_cast():
-    context.set_context(save_graphs=True)
     x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
     net = NetForCast()
     net.add_flags_recursive(fp16=True)
-    _executor.compile(net, x)
+    net(x)
 
 
 class IRBlockZ(nn.Cell):
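
A rough end-to-end sketch of the behaviour this patch targets, for reviewers who want to try it locally. It is not part of the diff: it mirrors the updated test_cast case in tests/ut/python/model/test_mix_precision.py, and the construct body plus the exact import layout are assumptions for illustration.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter, context
    from mindspore.common import dtype as mstype

    # PyNative mode, where ops run eagerly through pynative_execute.cc
    context.set_context(mode=context.PYNATIVE_MODE)

    class NetForCast(nn.Cell):
        def __init__(self):
            super(NetForCast, self).__init__()
            # float32 attributes registered on the Cell
            # (Tensor attributes now go into Cell._tensor_list in PyNative mode)
            self.x1 = Tensor(1.0, mstype.float32)
            self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')

        def construct(self, x):
            # assumed body; the diff only shows __init__ for this net
            return x * self.x1 + self.x2

    x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
    net = NetForCast()
    # Marking the cell as fp16 makes cast_param() record the cast dtype via set_cast_dtype(),
    # and the C++ MakeCNode path inserts the mixed-precision/implicit casts before each op runs.
    net.add_flags_recursive(fp16=True)
    out = net(x)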