Do mixed precision in C++ for PyNative

This commit is contained in:
kpy 2020-08-31 15:28:03 +08:00 committed by Wei Luning
parent 3bc3f8ed8e
commit 44c01e27c0
8 changed files with 149 additions and 69 deletions

View File

@ -256,41 +256,84 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
return RunOp(args)[0];
}
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) {
auto &out_args = op_exec_info->op_inputs;
auto signature = prim->signatures();
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature &sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return;
// Casts a single tensor input to the dtype stored on the tensor itself (Tensor::cast_dtype()).
//
// is_cast: out-flag; set to true only when a cast is actually performed. It is never reset to false here.
// obj:     Python object that is converted to tensor::TensorPtr via py::cast.
// Returns the result of DoAutoCast when all of the following hold: the tensor has a cast dtype, its element
// dtype is available, that element dtype is a sub-type of kFloat, and it differs from the cast dtype.
// In every other case the returned py::object is left default-constructed, so callers should consult
// *is_cast before using the return value.
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
auto tensor = py::cast<tensor::TensorPtr>(obj);
auto cast_type = tensor->cast_dtype();
py::object cast_output;
if (cast_type != nullptr) {
auto source_element = tensor->Dtype();
// Only floating-point tensors whose dtype differs from the requested one are cast.
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
MS_LOG(DEBUG) << "cast to " << cast_type->ToString();
cast_output = DoAutoCast(obj, cast_type->type_id());
*is_cast = true;
}
}
// NOTE(review): the next two lines use `dtypes` and `out_args`, which are not declared in this function;
// they look like removed-side diff lines from the old ConvertInputs body - confirm before compiling.
auto type_indexes = GetTypeIndex(dtypes);
auto dst_type = GetDstType(out_args, type_indexes);
return cast_output;
}
for (size_t i = 0; i < dtypes.size(); ++i) {
if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
// Applies parameter mix-precision casting to every element of `tuple`, recursing into nested tuples.
//
// is_cast: out-flag; set to true as soon as any (nested) tensor element is actually cast. It is never reset
//          to false here, so the caller must initialize it.
// tuple:   a Python tuple passed as one op input.
// Returns a new tuple of the same length. Elements that were cast are replaced by the cast result; every
// other element (tensors that need no cast, scalars, strings, ...) is carried over unchanged so that no slot
// of the result is left unset.
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
  auto tuple_size = static_cast<int>(tuple.size());
  py::tuple result(tuple_size);
  for (int i = 0; i < tuple_size; i++) {
    const py::object item = tuple[i];
    if (py::isinstance<tensor::MetaTensor>(item)) {
      MS_LOG(DEBUG) << "call cast for item " << i;
      // DoParamMixPrecisionCast only produces a usable return value when it really casts, so track this
      // element with its own flag and fall back to the original tensor otherwise.
      bool item_is_cast = false;
      py::object cast_item = DoParamMixPrecisionCast(&item_is_cast, item);
      if (item_is_cast) {
        *is_cast = true;
        result[i] = cast_item;
      } else {
        result[i] = item;
      }
    } else if (py::isinstance<py::tuple>(item)) {
      result[i] = DoParamMixPrecisionCastTuple(is_cast, item);
    } else {
      // Keep non-tensor, non-tuple items: a freshly sized py::tuple starts with empty slots, and the caller
      // replaces the whole op input with `result` once *is_cast is true.
      result[i] = item;
    }
  }
  return result;
}
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
auto signature = prim->signatures();
bool has_sig_dtype = false;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
[&has_sig_dtype](const Signature &sig) {
auto dtype = sig.dtype;
if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
has_sig_dtype = true;
}
return dtype;
});
return has_sig_dtype;
}
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) {
const auto &signature = prim->signatures();
auto &out_args = op_exec_info->op_inputs;
bool has_dtype_sig = (dtypes.size() > 0);
for (size_t i = 0; i < out_args.size(); ++i) {
MS_LOG(DEBUG) << "check inputs " << i;
auto obj = out_args[i];
auto sig = SignatureEnumRW::kRWDefault;
if (signature.size() > 0) {
sig = signature[i].rw;
}
bool is_parameter = false;
TypeId arg_type_id = kTypeUnknown;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto arg = py::cast<tensor::MetaTensorPtr>(obj);
if (arg->is_parameter()) {
is_parameter = true;
MS_LOG(DEBUG) << "parameter is read " << i;
}
arg_type_id = arg->data_type();
}
// No need to implicit cast if no dtype.
if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
continue;
}
auto it = dst_type.find(dtypes[i]);
if (it == dst_type.end() || it->second == kTypeUnknown) {
continue;
}
auto obj = out_args[i];
auto sig = signature[i].rw;
bool is_parameter = false;
// implicit cast
bool is_same_type = false;
TypeId arg_type_id = kTypeUnknown;
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto arg = py::cast<tensor::MetaTensorPtr>(obj);
if (arg->is_parameter()) {
is_parameter = true;
}
arg_type_id = arg->data_type();
}
if (arg_type_id != 0) {
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
}
@ -317,7 +360,6 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
}
py::object cast_output = DoAutoCast(out_args[i], it->second);
out_args[i] = cast_output;
ValuePtr input_value = PyAttrValue(cast_output);
}
}
@ -346,7 +388,6 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->op_inputs = args[PY_INPUTS];
ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
return op_exec_info;
}
@ -697,11 +738,53 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
inputs.push_back(NewValueNode(prim));
size_t size = op_exec_info->op_inputs.size();
auto const_input_index = prim->get_const_input_indexes();
bool have_const_input = !const_input_index.empty();
bool is_const_prim = prim->is_const_prim();
// ignore signature for cast op
bool is_cast_op = (op_exec_info->op_name == "Cast");
if (!is_cast_op) {
const auto &signature = prim->signatures();
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
auto sig = SignatureEnumRW::kRWDefault;
if (signature.size() > 0) {
sig = signature[i].rw;
}
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
<< std::string(py::repr(obj));
// mix precision for non param
bool is_cast = false;
py::object cast_output;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor && meta_tensor->is_parameter()) {
if (sig != SignatureEnumRW::kRWRead) {
continue;
}
}
// redundant cast call if the tensor is a const Tensor.
cast_output = DoParamMixPrecisionCast(&is_cast, obj);
} else if (py::isinstance<py::tuple>(obj)) {
// mix precision for tuple inputs
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
}
if (is_cast) {
op_exec_info->op_inputs[i] = cast_output;
}
}
std::vector<SignatureEnumDType> dtypes;
bool has_dtype_sig = GetSignatureType(prim, &dtypes);
std::map<SignatureEnumDType, TypeId> dst_types;
if (has_dtype_sig) {
// fetch info for implicit cast
auto type_indexes = GetTypeIndex(dtypes);
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
}
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
}
MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
const auto &obj = op_exec_info->op_inputs[i];
bool op_mask = false;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
@ -709,9 +792,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
op_mask = meta_tensor->is_parameter();
}
}
(*op_masks).push_back(op_mask);
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ "
<< grad_flag_;
AnfNodePtr node = nullptr;
@ -726,6 +808,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if (node != nullptr && node->abstract() != nullptr) {
abs = node->abstract();
}
auto const_input_index = prim->get_const_input_indexes();
bool have_const_input = !const_input_index.empty();
bool is_const_prim = prim->is_const_prim();
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
<< prim->is_const_prim();
bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
@ -926,7 +1012,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
}
py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
MS_LOG(DEBUG) << "RunOp start" << args.size();
MS_LOG(DEBUG) << "RunOp start " << args.size();
OpExecInfoPtr op_exec_info = nullptr;
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
auto name = py::cast<std::string>(args[PY_NAME]);

View File

@ -451,6 +451,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.set_dtype(mindspore.int32)
mindspore.int32
)mydelimiter")
.def("set_cast_dtype", &Tensor::set_cast_dtype)
.def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle(

View File

@ -292,6 +292,7 @@ class _PynativeExecutor:
def __init__(self):
self._executor = PynativeExecutor_.get_instance()
#TODO(kpy):add a type arg
def new_graph(self, obj, *args, **kwargs):
self._executor.new_graph(obj, *args, *(kwargs.values()))

View File

@ -216,16 +216,6 @@ class Parameter(MetaTensor):
raise ValueError("The type of the name should be `str` or `None`.")
self._param_info.name = name_
@property
def cast_type(self):
return self._cast_type
@cast_type.setter
def cast_type(self, dst_type):
if dst_type not in (mstype.float16, mstype.float32, None):
raise ValueError("The type of the name should be type of [float32, float16] or `None`.")
self._cast_type = dst_type
@property
def sliced(self):
"""Get slice status of the parameter."""

View File

@ -268,6 +268,8 @@ class Tensor : public MetaTensor {
std::vector<Axis> padding_type() const { return padding_type_; }
std::string id() const { return id_; }
TypePtr cast_dtype() { return cast_dtype_; }
void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; }
void SetNeedWait(bool need_wait) {
if (event_ != nullptr) {
@ -310,6 +312,7 @@ class Tensor : public MetaTensor {
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
TypePtr cast_dtype_{nullptr};
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;

View File

@ -61,15 +61,17 @@ class Cell(Cell_):
"""
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
'_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode',
'_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced',
'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
'_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
'_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
'_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
'_auto_parallel_compile_and_run', 'cell_type']
def __init__(self, auto_prefix=True, flags=None):
Cell_.__init__(self, self._cell_tag)
self._params = OrderedDict()
self._cells = OrderedDict()
self._params_list = OrderedDict()
self._tensor_list = OrderedDict()
self.training = False
self.requires_grad = False
self.pynative = False
@ -228,6 +230,9 @@ class Cell(Cell_):
return cells[name]
if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
params_list = self.__dict__['_params_list']
tensor_list = self.__dict__['_tensor_list']
if name in tensor_list:
return self.cast_param(tensor_list[name])
if name in params_list:
para_list = params_list[name]
cast_list = list()
@ -339,6 +344,7 @@ class Cell(Cell_):
cells = self.__dict__.get('_cells')
params = self.__dict__.get('_params')
params_list = self.__dict__.get('_params_list')
tensor_list = self.__dict__.get('_tensor_list')
if isinstance(value, Parameter):
if params is None:
raise AttributeError("Can not assign params before Cell.__init__() call.")
@ -383,6 +389,13 @@ class Cell(Cell_):
if value is not None:
raise TypeError("Expected type is cell, but got {}.".format(type(value)))
self._cells[name] = None
elif isinstance(value, Tensor):
if context.get_context("mode") == context.PYNATIVE_MODE:
if name in self.__dict__:
del self.__dict__[name]
tensor_list[name] = value
else:
object.__setattr__(self, name, value)
else:
if isinstance(value, Primitive):
value.set_prim_instance_name(name)
@ -570,11 +583,9 @@ class Cell(Cell_):
"""
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
param.cast_type = mstype.float16
elif self._mindspore_flags.get('fp32'):
param.cast_type = mstype.float32
else:
param.cast_type = None
param.set_cast_dtype(mstype.float16)
if self._mindspore_flags.get('fp32'):
param.set_cast_dtype(mstype.float32)
return param
def insert_child_to_cell(self, child_name, child):

View File

@ -17,7 +17,6 @@
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
from . import signature as sig
@ -490,16 +489,7 @@ def constexpr(fn=None, get_instance=True, name=None):
@_wrap_func
def _run_op(obj, op_name, args):
"""Single op execution function supported by ge in PyNative mode."""
cast = tensor_operator_registry.get("cast")
if op_name == "Cast" or obj.update_parameter:
cast_args = args
else:
cast_args = list(args)
for idx, arg in enumerate(args):
cast_type = getattr(arg, "cast_type", None)
if cast_type:
cast_args[idx] = cast(arg, cast_type)
output = real_run_op(obj, op_name, cast_args)
output = real_run_op(obj, op_name, args)
if not output:
raise RuntimeError("Pynative run op %s failed!" % op_name)
if len(output) == 1:

View File

@ -74,7 +74,7 @@ def test_add_cast_flag():
net.fc3.to_float(mstype.float32)
net = train_step_with_loss_warp(net)
net.set_train()
_executor.compile(net, predict, label)
net(predict, label)
def test_add_cast_flag_tensor():
@ -82,7 +82,7 @@ def test_add_cast_flag_tensor():
net = NetForConcat()
net.add_flags_recursive(fp16=True)
net.set_train()
_executor.compile(net, x1)
net(x1)
def test_on_momentum():
@ -91,7 +91,7 @@ def test_on_momentum():
net = LeNet5()
net = train_step_with_loss_warp(net).to_float(mstype.float16)
net.set_train()
_executor.compile(net, predict, label)
net(predict, label)
def test_data_parallel_with_cast():
@ -134,7 +134,6 @@ def test_nn_prelu():
class NetForCast(nn.Cell):
def __init__(self):
super(NetForCast, self).__init__()
self.concat = P.Concat()
self.x1 = Tensor(1.0, mstype.float32)
self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
@ -144,11 +143,10 @@ class NetForCast(nn.Cell):
def test_cast():
context.set_context(save_graphs=True)
x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
net = NetForCast()
net.add_flags_recursive(fp16=True)
_executor.compile(net, x)
net(x)
class IRBlockZ(nn.Cell):