forked from OSSInnovation/mindspore
!6320 change mix_precision to c++
Merge pull request !6320 from vlne-v1/pynative_amp
This commit is contained in:
commit
3aa07a4362
|
@ -257,41 +257,84 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
|
|||
return RunOp(args)[0];
|
||||
}
|
||||
|
||||
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) {
|
||||
auto &out_args = op_exec_info->op_inputs;
|
||||
auto signature = prim->signatures();
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||
[](const Signature &sig) { return sig.dtype; });
|
||||
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
||||
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
||||
return;
|
||||
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
|
||||
auto tensor = py::cast<tensor::TensorPtr>(obj);
|
||||
auto cast_type = tensor->cast_dtype();
|
||||
py::object cast_output;
|
||||
if (cast_type != nullptr) {
|
||||
auto source_element = tensor->Dtype();
|
||||
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
|
||||
MS_LOG(DEBUG) << "cast to " << cast_type->ToString();
|
||||
cast_output = DoAutoCast(obj, cast_type->type_id());
|
||||
*is_cast = true;
|
||||
}
|
||||
auto type_indexes = GetTypeIndex(dtypes);
|
||||
auto dst_type = GetDstType(out_args, type_indexes);
|
||||
}
|
||||
return cast_output;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||
if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
|
||||
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
|
||||
auto tuple_size = static_cast<int>(tuple.size());
|
||||
py::tuple result(tuple_size);
|
||||
|
||||
for (int i = 0; i < tuple_size; i++) {
|
||||
if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
|
||||
MS_LOG(DEBUG) << "call cast for item " << i;
|
||||
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
|
||||
} else if (py::isinstance<py::tuple>(tuple[i])) {
|
||||
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
|
||||
auto signature = prim->signatures();
|
||||
bool has_sig_dtype = false;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
|
||||
[&has_sig_dtype](const Signature &sig) {
|
||||
auto dtype = sig.dtype;
|
||||
if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
|
||||
has_sig_dtype = true;
|
||||
}
|
||||
return dtype;
|
||||
});
|
||||
return has_sig_dtype;
|
||||
}
|
||||
|
||||
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
|
||||
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) {
|
||||
const auto &signature = prim->signatures();
|
||||
auto &out_args = op_exec_info->op_inputs;
|
||||
bool has_dtype_sig = (dtypes.size() > 0);
|
||||
for (size_t i = 0; i < out_args.size(); ++i) {
|
||||
MS_LOG(DEBUG) << "check inputs " << i;
|
||||
auto obj = out_args[i];
|
||||
auto sig = SignatureEnumRW::kRWDefault;
|
||||
if (signature.size() > 0) {
|
||||
sig = signature[i].rw;
|
||||
}
|
||||
bool is_parameter = false;
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto arg = py::cast<tensor::MetaTensorPtr>(obj);
|
||||
if (arg->is_parameter()) {
|
||||
is_parameter = true;
|
||||
MS_LOG(DEBUG) << "parameter is read " << i;
|
||||
}
|
||||
arg_type_id = arg->data_type();
|
||||
}
|
||||
|
||||
// No need to implicit cast if no dtype.
|
||||
if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
|
||||
continue;
|
||||
}
|
||||
auto it = dst_type.find(dtypes[i]);
|
||||
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto obj = out_args[i];
|
||||
auto sig = signature[i].rw;
|
||||
bool is_parameter = false;
|
||||
// implicit cast
|
||||
bool is_same_type = false;
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto arg = py::cast<tensor::MetaTensorPtr>(obj);
|
||||
if (arg->is_parameter()) {
|
||||
is_parameter = true;
|
||||
}
|
||||
arg_type_id = arg->data_type();
|
||||
}
|
||||
if (arg_type_id != 0) {
|
||||
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
|
||||
}
|
||||
|
@ -318,7 +361,6 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
|
|||
}
|
||||
py::object cast_output = DoAutoCast(out_args[i], it->second);
|
||||
out_args[i] = cast_output;
|
||||
ValuePtr input_value = PyAttrValue(cast_output);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -347,7 +389,6 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|||
op_exec_info->py_primitive = prim;
|
||||
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
||||
op_exec_info->op_inputs = args[PY_INPUTS];
|
||||
ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
|
||||
return op_exec_info;
|
||||
}
|
||||
|
||||
|
@ -698,11 +739,53 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|||
inputs.push_back(NewValueNode(prim));
|
||||
|
||||
size_t size = op_exec_info->op_inputs.size();
|
||||
auto const_input_index = prim->get_const_input_indexes();
|
||||
bool have_const_input = !const_input_index.empty();
|
||||
bool is_const_prim = prim->is_const_prim();
|
||||
// ignore signature for cast op
|
||||
bool is_cast_op = (op_exec_info->op_name == "Cast");
|
||||
if (!is_cast_op) {
|
||||
const auto &signature = prim->signatures();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto obj = op_exec_info->op_inputs[i];
|
||||
auto sig = SignatureEnumRW::kRWDefault;
|
||||
if (signature.size() > 0) {
|
||||
sig = signature[i].rw;
|
||||
}
|
||||
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
|
||||
<< std::string(py::repr(obj));
|
||||
// mix precision for non param
|
||||
bool is_cast = false;
|
||||
py::object cast_output;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||
if (meta_tensor && meta_tensor->is_parameter()) {
|
||||
if (sig != SignatureEnumRW::kRWRead) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// redundant cast call if the tensor is a const Tensor.
|
||||
cast_output = DoParamMixPrecisionCast(&is_cast, obj);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
// mix precision for tuple inputs
|
||||
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
|
||||
}
|
||||
if (is_cast) {
|
||||
op_exec_info->op_inputs[i] = cast_output;
|
||||
}
|
||||
}
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
|
||||
bool has_dtype_sig = GetSignatureType(prim, &dtypes);
|
||||
std::map<SignatureEnumDType, TypeId> dst_types;
|
||||
if (has_dtype_sig) {
|
||||
// fetch info for implicit cast
|
||||
auto type_indexes = GetTypeIndex(dtypes);
|
||||
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
|
||||
}
|
||||
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
|
||||
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
|
||||
}
|
||||
MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
const auto &obj = op_exec_info->op_inputs[i];
|
||||
bool op_mask = false;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||
|
@ -710,9 +793,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|||
op_mask = meta_tensor->is_parameter();
|
||||
}
|
||||
}
|
||||
|
||||
(*op_masks).push_back(op_mask);
|
||||
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
|
||||
MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ "
|
||||
<< grad_flag_;
|
||||
|
||||
AnfNodePtr node = nullptr;
|
||||
|
@ -727,6 +809,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|||
if (node != nullptr && node->abstract() != nullptr) {
|
||||
abs = node->abstract();
|
||||
}
|
||||
|
||||
auto const_input_index = prim->get_const_input_indexes();
|
||||
bool have_const_input = !const_input_index.empty();
|
||||
bool is_const_prim = prim->is_const_prim();
|
||||
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
|
||||
<< prim->is_const_prim();
|
||||
bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
|
||||
|
@ -998,7 +1084,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
|||
}
|
||||
|
||||
py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
|
||||
MS_LOG(DEBUG) << "RunOp start" << args.size();
|
||||
MS_LOG(DEBUG) << "RunOp start " << args.size();
|
||||
OpExecInfoPtr op_exec_info = nullptr;
|
||||
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
||||
auto name = py::cast<std::string>(args[PY_NAME]);
|
||||
|
|
|
@ -455,6 +455,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
>>> data.set_dtype(mindspore.int32)
|
||||
mindspore.int32
|
||||
)mydelimiter")
|
||||
.def("set_cast_dtype", &Tensor::set_cast_dtype)
|
||||
.def("__str__", &Tensor::ToString)
|
||||
.def("__repr__", &Tensor::ToStringRepr)
|
||||
.def(py::pickle(
|
||||
|
|
|
@ -292,6 +292,7 @@ class _PynativeExecutor:
|
|||
def __init__(self):
|
||||
self._executor = PynativeExecutor_.get_instance()
|
||||
|
||||
#TODO(kpy):add a type arg
|
||||
def new_graph(self, obj, *args, **kwargs):
|
||||
self._executor.new_graph(obj, *args, *(kwargs.values()))
|
||||
|
||||
|
|
|
@ -219,16 +219,6 @@ class Parameter(MetaTensor):
|
|||
raise ValueError("The type of the name should be `str` or `None`.")
|
||||
self._param_info.name = name_
|
||||
|
||||
@property
|
||||
def cast_type(self):
|
||||
return self._cast_type
|
||||
|
||||
@cast_type.setter
|
||||
def cast_type(self, dst_type):
|
||||
if dst_type not in (mstype.float16, mstype.float32, None):
|
||||
raise ValueError("The type of the name should be type of [float32, float16] or `None`.")
|
||||
self._cast_type = dst_type
|
||||
|
||||
@property
|
||||
def sliced(self):
|
||||
"""Get slice status of the parameter."""
|
||||
|
|
|
@ -268,6 +268,8 @@ class Tensor : public MetaTensor {
|
|||
std::vector<Axis> padding_type() const { return padding_type_; }
|
||||
|
||||
std::string id() const { return id_; }
|
||||
TypePtr cast_dtype() { return cast_dtype_; }
|
||||
void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; }
|
||||
|
||||
void SetNeedWait(bool need_wait) {
|
||||
if (event_ != nullptr) {
|
||||
|
@ -310,6 +312,7 @@ class Tensor : public MetaTensor {
|
|||
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
|
||||
DeviceSyncPtr device_sync_{nullptr};
|
||||
std::vector<Axis> padding_type_;
|
||||
TypePtr cast_dtype_{nullptr};
|
||||
};
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
|
||||
|
|
|
@ -61,15 +61,17 @@ class Cell(Cell_):
|
|||
"""
|
||||
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
|
||||
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
|
||||
'_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode',
|
||||
'_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced',
|
||||
'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
|
||||
'_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
|
||||
'_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
|
||||
'_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
|
||||
'_auto_parallel_compile_and_run', 'cell_type']
|
||||
|
||||
def __init__(self, auto_prefix=True, flags=None):
|
||||
Cell_.__init__(self, self._cell_tag)
|
||||
self._params = OrderedDict()
|
||||
self._cells = OrderedDict()
|
||||
self._params_list = OrderedDict()
|
||||
self._tensor_list = OrderedDict()
|
||||
self.training = False
|
||||
self.requires_grad = False
|
||||
self.pynative = False
|
||||
|
@ -228,6 +230,9 @@ class Cell(Cell_):
|
|||
return cells[name]
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
|
||||
params_list = self.__dict__['_params_list']
|
||||
tensor_list = self.__dict__['_tensor_list']
|
||||
if name in tensor_list:
|
||||
return self.cast_param(tensor_list[name])
|
||||
if name in params_list:
|
||||
para_list = params_list[name]
|
||||
cast_list = list()
|
||||
|
@ -339,6 +344,7 @@ class Cell(Cell_):
|
|||
cells = self.__dict__.get('_cells')
|
||||
params = self.__dict__.get('_params')
|
||||
params_list = self.__dict__.get('_params_list')
|
||||
tensor_list = self.__dict__.get('_tensor_list')
|
||||
if isinstance(value, Parameter):
|
||||
if params is None:
|
||||
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
||||
|
@ -383,6 +389,13 @@ class Cell(Cell_):
|
|||
if value is not None:
|
||||
raise TypeError("Expected type is cell, but got {}.".format(type(value)))
|
||||
self._cells[name] = None
|
||||
elif isinstance(value, Tensor):
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
if name in self.__dict__:
|
||||
del self.__dict__[name]
|
||||
tensor_list[name] = value
|
||||
else:
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
if isinstance(value, Primitive):
|
||||
value.set_prim_instance_name(name)
|
||||
|
@ -570,11 +583,9 @@ class Cell(Cell_):
|
|||
"""
|
||||
if hasattr(self, "_mindspore_flags"):
|
||||
if self._mindspore_flags.get('fp16'):
|
||||
param.cast_type = mstype.float16
|
||||
elif self._mindspore_flags.get('fp32'):
|
||||
param.cast_type = mstype.float32
|
||||
else:
|
||||
param.cast_type = None
|
||||
param.set_cast_dtype(mstype.float16)
|
||||
if self._mindspore_flags.get('fp32'):
|
||||
param.set_cast_dtype(mstype.float32)
|
||||
return param
|
||||
|
||||
def insert_child_to_cell(self, child_name, child):
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
import inspect
|
||||
import copy
|
||||
from mindspore.common.api import _wrap_func
|
||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from mindspore import context
|
||||
from .._c_expression import Primitive_, real_run_op, prim_type
|
||||
from . import signature as sig
|
||||
|
@ -496,16 +495,7 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|||
@_wrap_func
|
||||
def _run_op(obj, op_name, args):
|
||||
"""Single op execution function supported by ge in PyNative mode."""
|
||||
cast = tensor_operator_registry.get("cast")
|
||||
if op_name == "Cast" or obj.update_parameter:
|
||||
cast_args = args
|
||||
else:
|
||||
cast_args = list(args)
|
||||
for idx, arg in enumerate(args):
|
||||
cast_type = getattr(arg, "cast_type", None)
|
||||
if cast_type:
|
||||
cast_args[idx] = cast(arg, cast_type)
|
||||
output = real_run_op(obj, op_name, cast_args)
|
||||
output = real_run_op(obj, op_name, args)
|
||||
if not output:
|
||||
raise RuntimeError("Pynative run op %s failed!" % op_name)
|
||||
if len(output) == 1:
|
||||
|
|
|
@ -74,7 +74,7 @@ def test_add_cast_flag():
|
|||
net.fc3.to_float(mstype.float32)
|
||||
net = train_step_with_loss_warp(net)
|
||||
net.set_train()
|
||||
_executor.compile(net, predict, label)
|
||||
net(predict, label)
|
||||
|
||||
|
||||
def test_add_cast_flag_tensor():
|
||||
|
@ -82,7 +82,7 @@ def test_add_cast_flag_tensor():
|
|||
net = NetForConcat()
|
||||
net.add_flags_recursive(fp16=True)
|
||||
net.set_train()
|
||||
_executor.compile(net, x1)
|
||||
net(x1)
|
||||
|
||||
|
||||
def test_on_momentum():
|
||||
|
@ -91,7 +91,7 @@ def test_on_momentum():
|
|||
net = LeNet5()
|
||||
net = train_step_with_loss_warp(net).to_float(mstype.float16)
|
||||
net.set_train()
|
||||
_executor.compile(net, predict, label)
|
||||
net(predict, label)
|
||||
|
||||
|
||||
def test_data_parallel_with_cast():
|
||||
|
@ -134,7 +134,6 @@ def test_nn_prelu():
|
|||
class NetForCast(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetForCast, self).__init__()
|
||||
self.concat = P.Concat()
|
||||
self.x1 = Tensor(1.0, mstype.float32)
|
||||
self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
|
||||
|
||||
|
@ -144,11 +143,10 @@ class NetForCast(nn.Cell):
|
|||
|
||||
|
||||
def test_cast():
|
||||
context.set_context(save_graphs=True)
|
||||
x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
|
||||
net = NetForCast()
|
||||
net.add_flags_recursive(fp16=True)
|
||||
_executor.compile(net, x)
|
||||
net(x)
|
||||
|
||||
|
||||
class IRBlockZ(nn.Cell):
|
||||
|
|
Loading…
Reference in New Issue