forked from OSSInnovation/mindspore
!6320 change mix_precision to c++
Merge pull request !6320 from vlne-v1/pynative_amp
commit 3aa07a4362
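In short: this moves PyNative mixed-precision handling from Python into C++. The per-argument `cast_type` loop in `_run_op` (primitive.py) and the `cast_type` property on `Parameter` are removed; `Tensor` gains a `cast_dtype_` field with a `set_cast_dtype` pybind method; `PynativeExecutor::MakeCNode` now performs the parameter/tuple mixed-precision casts and signature-based implicit casts (`DoParamMixPrecisionCast`, `DoSignatrueCast`); and `Cell` tracks plain `Tensor` attributes in a new `_tensor_list` so they are cast on access in PyNative mode.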
@@ -257,41 +257,84 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
   return RunOp(args)[0];
 }
 
-void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) {
-  auto &out_args = op_exec_info->op_inputs;
-  auto signature = prim->signatures();
-  std::vector<SignatureEnumDType> dtypes;
-  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
-                       [](const Signature &sig) { return sig.dtype; });
-  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
-  if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
-    return;
-  }
-  auto type_indexes = GetTypeIndex(dtypes);
-  auto dst_type = GetDstType(out_args, type_indexes);
-  for (size_t i = 0; i < dtypes.size(); ++i) {
-    if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
+py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
+  auto tensor = py::cast<tensor::TensorPtr>(obj);
+  auto cast_type = tensor->cast_dtype();
+  py::object cast_output;
+  if (cast_type != nullptr) {
+    auto source_element = tensor->Dtype();
+    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
+      MS_LOG(DEBUG) << "cast to " << cast_type->ToString();
+      cast_output = DoAutoCast(obj, cast_type->type_id());
+      *is_cast = true;
+    }
+  }
+  return cast_output;
+}
+
+py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
+  auto tuple_size = static_cast<int>(tuple.size());
+  py::tuple result(tuple_size);
+  for (int i = 0; i < tuple_size; i++) {
+    if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
+      MS_LOG(DEBUG) << "call cast for item " << i;
+      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
+    } else if (py::isinstance<py::tuple>(tuple[i])) {
+      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
+    }
+  }
+  return result;
+}
+
+bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
+  auto signature = prim->signatures();
+  bool has_sig_dtype = false;
+  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
+                       [&has_sig_dtype](const Signature &sig) {
+                         auto dtype = sig.dtype;
+                         if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
+                           has_sig_dtype = true;
+                         }
+                         return dtype;
+                       });
+  return has_sig_dtype;
+}
+
+void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
+                     const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) {
+  const auto &signature = prim->signatures();
+  auto &out_args = op_exec_info->op_inputs;
+  bool has_dtype_sig = (dtypes.size() > 0);
+  for (size_t i = 0; i < out_args.size(); ++i) {
+    MS_LOG(DEBUG) << "check inputs " << i;
+    auto obj = out_args[i];
+    auto sig = SignatureEnumRW::kRWDefault;
+    if (signature.size() > 0) {
+      sig = signature[i].rw;
+    }
+    bool is_parameter = false;
+    TypeId arg_type_id = kTypeUnknown;
+    if (py::isinstance<tensor::MetaTensor>(obj)) {
+      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
+      if (arg->is_parameter()) {
+        is_parameter = true;
+        MS_LOG(DEBUG) << "parameter is read " << i;
+      }
+      arg_type_id = arg->data_type();
+    }
+
+    // No need to implicit cast if no dtype.
+    if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
       continue;
     }
     auto it = dst_type.find(dtypes[i]);
     if (it == dst_type.end() || it->second == kTypeUnknown) {
       continue;
     }
+    // implicit cast
-    auto obj = out_args[i];
-    auto sig = signature[i].rw;
-    bool is_parameter = false;
     bool is_same_type = false;
-    TypeId arg_type_id = kTypeUnknown;
     bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
-    if (py::isinstance<tensor::MetaTensor>(obj)) {
-      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
-      if (arg->is_parameter()) {
-        is_parameter = true;
-      }
-      arg_type_id = arg->data_type();
-    }
     if (arg_type_id != 0) {
       is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
     }
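Taken together, `GetSignatureType`, `GetTypeIndex`/`GetDstType`, and `DoSignatrueCast` implement the implicit dtype promotion that previously lived in `ConvertInputs`. A minimal sketch of the behavior this enables from Python; the op choice, shapes, and the exact promoted dtype here are assumptions for illustration, not part of this diff:

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

# TensorAdd carries a dtype signature, so mixed float inputs should be
# implicitly promoted to a common dtype before the kernel runs.
x = Tensor(np.ones([2, 2]).astype(np.float16))
y = Tensor(np.ones([2, 2]).astype(np.float32))
out = P.TensorAdd()(x, y)
print(out.dtype)  # expected: mindspore.float32 (promotion picks the wider float)
```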
@@ -318,7 +361,6 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
     }
     py::object cast_output = DoAutoCast(out_args[i], it->second);
     out_args[i] = cast_output;
-    ValuePtr input_value = PyAttrValue(cast_output);
   }
 }
 
@@ -347,7 +389,6 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
   op_exec_info->op_inputs = args[PY_INPUTS];
-  ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
   return op_exec_info;
 }
 
@@ -698,11 +739,53 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
   inputs.push_back(NewValueNode(prim));
 
   size_t size = op_exec_info->op_inputs.size();
-  auto const_input_index = prim->get_const_input_indexes();
-  bool have_const_input = !const_input_index.empty();
-  bool is_const_prim = prim->is_const_prim();
+  // ignore signature for cast op
+  bool is_cast_op = (op_exec_info->op_name == "Cast");
+  if (!is_cast_op) {
+    const auto &signature = prim->signatures();
+    for (size_t i = 0; i < size; i++) {
+      auto obj = op_exec_info->op_inputs[i];
+      auto sig = SignatureEnumRW::kRWDefault;
+      if (signature.size() > 0) {
+        sig = signature[i].rw;
+      }
+      MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
+                    << std::string(py::repr(obj));
+      // mix precision for non param
+      bool is_cast = false;
+      py::object cast_output;
+      if (py::isinstance<tensor::MetaTensor>(obj)) {
+        auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
+        if (meta_tensor && meta_tensor->is_parameter()) {
+          if (sig != SignatureEnumRW::kRWRead) {
+            continue;
+          }
+        }
+        // redundant cast call if the tensor is a const Tensor.
+        cast_output = DoParamMixPrecisionCast(&is_cast, obj);
+      } else if (py::isinstance<py::tuple>(obj)) {
+        // mix precision for tuple inputs
+        cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
+      }
+      if (is_cast) {
+        op_exec_info->op_inputs[i] = cast_output;
+      }
+    }
+    std::vector<SignatureEnumDType> dtypes;
+
+    bool has_dtype_sig = GetSignatureType(prim, &dtypes);
+    std::map<SignatureEnumDType, TypeId> dst_types;
+    if (has_dtype_sig) {
+      // fetch info for implicit cast
+      auto type_indexes = GetTypeIndex(dtypes);
+      dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
+    }
+    MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
+    DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
+  }
+  MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
   for (size_t i = 0; i < size; i++) {
-    auto obj = op_exec_info->op_inputs[i];
+    const auto &obj = op_exec_info->op_inputs[i];
     bool op_mask = false;
     if (py::isinstance<tensor::MetaTensor>(obj)) {
       auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
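The new block in `MakeCNode` runs the mixed-precision casts once per op call, before the CNode is built, and skips the whole signature pass for `Cast` itself to avoid a redundant cast-of-a-cast. A hedged sketch of what this means for user code; the network and shapes are illustrative:

```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

net = nn.Dense(16, 8)
net.to_float(mstype.float16)  # marks parameters with cast_dtype = float16

x = Tensor(np.ones([4, 16]).astype(np.float16))
# The float32 weights are cast to float16 in C++ (DoParamMixPrecisionCast)
# when the matmul runs, instead of per-call in Python's _run_op as before.
out = net(x)
```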
@@ -710,9 +793,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
         op_mask = meta_tensor->is_parameter();
       }
     }
 
     (*op_masks).push_back(op_mask);
-    MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
+    MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ "
                   << grad_flag_;
 
     AnfNodePtr node = nullptr;
@@ -727,6 +809,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
     if (node != nullptr && node->abstract() != nullptr) {
       abs = node->abstract();
     }
 
+    auto const_input_index = prim->get_const_input_indexes();
+    bool have_const_input = !const_input_index.empty();
+    bool is_const_prim = prim->is_const_prim();
     MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                   << prim->is_const_prim();
     bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i);
@@ -998,7 +1084,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
 }
 
 py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
-  MS_LOG(DEBUG) << "RunOp start" << args.size();
+  MS_LOG(DEBUG) << "RunOp start " << args.size();
   OpExecInfoPtr op_exec_info = nullptr;
   auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
   auto name = py::cast<std::string>(args[PY_NAME]);
@@ -455,6 +455,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                            >>> data.set_dtype(mindspore.int32)
                            mindspore.int32
                        )mydelimiter")
+                       .def("set_cast_dtype", &Tensor::set_cast_dtype)
                        .def("__str__", &Tensor::ToString)
                        .def("__repr__", &Tensor::ToStringRepr)
                        .def(py::pickle(
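The new binding exposes the C++ `cast_dtype_` slot to Python. Normally `Cell.cast_param` sets it via `to_float`/`add_flags_recursive`, but it can also be set directly; a minimal sketch, assuming the usual `mindspore` imports:

```python
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor

t = Tensor(np.ones([2, 2]).astype(np.float32))
# Mark the tensor for a deferred cast; the actual float32 -> float16 cast
# happens in C++ (DoParamMixPrecisionCast) the next time t feeds a PyNative op.
t.set_cast_dtype(mstype.float16)
```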
@@ -292,6 +292,7 @@ class _PynativeExecutor:
     def __init__(self):
         self._executor = PynativeExecutor_.get_instance()
 
+    #TODO(kpy):add a type arg
     def new_graph(self, obj, *args, **kwargs):
         self._executor.new_graph(obj, *args, *(kwargs.values()))
 
@@ -219,16 +219,6 @@ class Parameter(MetaTensor):
             raise ValueError("The type of the name should be `str` or `None`.")
         self._param_info.name = name_
 
-    @property
-    def cast_type(self):
-        return self._cast_type
-
-    @cast_type.setter
-    def cast_type(self, dst_type):
-        if dst_type not in (mstype.float16, mstype.float32, None):
-            raise ValueError("The type of the name should be type of [float32, float16] or `None`.")
-        self._cast_type = dst_type
-
     @property
     def sliced(self):
         """Get slice status of the parameter."""
@@ -268,6 +268,8 @@ class Tensor : public MetaTensor {
   std::vector<Axis> padding_type() const { return padding_type_; }
 
   std::string id() const { return id_; }
+  TypePtr cast_dtype() { return cast_dtype_; }
+  void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; }
 
   void SetNeedWait(bool need_wait) {
     if (event_ != nullptr) {
@@ -310,6 +312,7 @@ class Tensor : public MetaTensor {
   mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
   DeviceSyncPtr device_sync_{nullptr};
   std::vector<Axis> padding_type_;
+  TypePtr cast_dtype_{nullptr};
 };
 using TensorPtr = std::shared_ptr<Tensor>;
 using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
@@ -61,15 +61,17 @@ class Cell(Cell_):
     """
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
                    '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
-                   '_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode',
-                   '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced',
-                   'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
+                   '_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
+                   '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
+                   '_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
+                   '_auto_parallel_compile_and_run', 'cell_type']
 
     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
         self._params = OrderedDict()
         self._cells = OrderedDict()
         self._params_list = OrderedDict()
+        self._tensor_list = OrderedDict()
         self.training = False
         self.requires_grad = False
         self.pynative = False
@@ -228,6 +230,9 @@ class Cell(Cell_):
             return cells[name]
         if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
             params_list = self.__dict__['_params_list']
+            tensor_list = self.__dict__['_tensor_list']
+            if name in tensor_list:
+                return self.cast_param(tensor_list[name])
             if name in params_list:
                 para_list = params_list[name]
                 cast_list = list()
@@ -339,6 +344,7 @@ class Cell(Cell_):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
         params_list = self.__dict__.get('_params_list')
+        tensor_list = self.__dict__.get('_tensor_list')
         if isinstance(value, Parameter):
             if params is None:
                 raise AttributeError("Can not assign params before Cell.__init__() call.")
@@ -383,6 +389,13 @@ class Cell(Cell_):
             if value is not None:
                 raise TypeError("Expected type is cell, but got {}.".format(type(value)))
             self._cells[name] = None
+        elif isinstance(value, Tensor):
+            if context.get_context("mode") == context.PYNATIVE_MODE:
+                if name in self.__dict__:
+                    del self.__dict__[name]
+                tensor_list[name] = value
+            else:
+                object.__setattr__(self, name, value)
         else:
             if isinstance(value, Primitive):
                 value.set_prim_instance_name(name)
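With `_tensor_list`, a plain `Tensor` assigned as a `Cell` attribute in PyNative mode is intercepted by `__setattr__` and re-read through `__getattr__` -> `cast_param`, so mixed-precision flags now also cover constant tensors (this is what the `test_add_cast_flag_tensor` change below exercises). A minimal sketch; the class name and shapes are illustrative:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

class NetWithConst(nn.Cell):
    def __init__(self):
        super(NetWithConst, self).__init__()
        self.add = P.TensorAdd()
        # A plain Tensor attribute: stored in _tensor_list in PyNative mode.
        self.const = Tensor(np.ones([2, 2]).astype(np.float32))

    def construct(self, x):
        # Reading self.const goes through __getattr__ -> cast_param, so the
        # fp16 flag set below marks it with cast_dtype = float16.
        return self.add(x, self.const)

net = NetWithConst()
net.add_flags_recursive(fp16=True)
out = net(Tensor(np.ones([2, 2]).astype(np.float16)))
```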
@@ -570,11 +583,9 @@ class Cell(Cell_):
         """
         if hasattr(self, "_mindspore_flags"):
             if self._mindspore_flags.get('fp16'):
-                param.cast_type = mstype.float16
-            elif self._mindspore_flags.get('fp32'):
-                param.cast_type = mstype.float32
-            else:
-                param.cast_type = None
+                param.set_cast_dtype(mstype.float16)
+            if self._mindspore_flags.get('fp32'):
+                param.set_cast_dtype(mstype.float32)
         return param
 
     def insert_child_to_cell(self, child_name, child):
@@ -17,7 +17,6 @@
 import inspect
 import copy
 from mindspore.common.api import _wrap_func
-from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore import context
 from .._c_expression import Primitive_, real_run_op, prim_type
 from . import signature as sig
@@ -496,16 +495,7 @@ def constexpr(fn=None, get_instance=True, name=None):
 @_wrap_func
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
-    cast = tensor_operator_registry.get("cast")
-    if op_name == "Cast" or obj.update_parameter:
-        cast_args = args
-    else:
-        cast_args = list(args)
-        for idx, arg in enumerate(args):
-            cast_type = getattr(arg, "cast_type", None)
-            if cast_type:
-                cast_args[idx] = cast(arg, cast_type)
-    output = real_run_op(obj, op_name, cast_args)
+    output = real_run_op(obj, op_name, args)
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:
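With the Python-side loop gone, `_run_op` hands the arguments straight to `real_run_op`; any pending `cast_dtype` has already been applied in C++ by `PynativeExecutor::MakeCNode`, which is also why the `tensor_operator_registry` import above is removed.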
@@ -74,7 +74,7 @@ def test_add_cast_flag():
     net.fc3.to_float(mstype.float32)
     net = train_step_with_loss_warp(net)
     net.set_train()
-    _executor.compile(net, predict, label)
+    net(predict, label)
 
 
 def test_add_cast_flag_tensor():
@@ -82,7 +82,7 @@ def test_add_cast_flag_tensor():
     net = NetForConcat()
     net.add_flags_recursive(fp16=True)
     net.set_train()
-    _executor.compile(net, x1)
+    net(x1)
 
 
 def test_on_momentum():
@@ -91,7 +91,7 @@ def test_on_momentum():
     net = LeNet5()
     net = train_step_with_loss_warp(net).to_float(mstype.float16)
     net.set_train()
-    _executor.compile(net, predict, label)
+    net(predict, label)
 
 
 def test_data_parallel_with_cast():
@@ -134,7 +134,6 @@ def test_nn_prelu():
 class NetForCast(nn.Cell):
     def __init__(self):
         super(NetForCast, self).__init__()
-        self.concat = P.Concat()
         self.x1 = Tensor(1.0, mstype.float32)
         self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
 
@@ -144,11 +143,10 @@ class NetForCast(nn.Cell):
 
 
 def test_cast():
-    context.set_context(save_graphs=True)
     x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
     net = NetForCast()
     net.add_flags_recursive(fp16=True)
-    _executor.compile(net, x)
+    net(x)
 
 
 class IRBlockZ(nn.Cell):