forked from OSSInnovation/mindspore
!6612 [bug]fix bug in Parameter flag set in pynative amp && code style in pynative_exector.cc
Merge pull request !6612 from vlne-v1/code_style
This commit is contained in:
commit
6c26629404
|
@ -288,6 +288,7 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
|
|||
}
|
||||
|
||||
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
|
||||
MS_EXCEPTION_IF_NULL(dtypes);
|
||||
auto signature = prim->signatures();
|
||||
bool has_sig_dtype = false;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
|
||||
|
@ -733,20 +734,29 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
|
|||
|
||||
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(op_masks);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list);
|
||||
CNodePtr cnode = nullptr;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
|
||||
auto prim = op_exec_info->py_primitive;
|
||||
const auto &signature = prim->signatures();
|
||||
|
||||
inputs.push_back(NewValueNode(prim));
|
||||
|
||||
size_t size = op_exec_info->op_inputs.size();
|
||||
auto sig_size = signature.size();
|
||||
// ignore signature for cast op
|
||||
if (sig_size > 0 && sig_size != size) {
|
||||
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
|
||||
<< "inputs size " << sig_size;
|
||||
}
|
||||
bool is_cast_op = (op_exec_info->op_name == "Cast");
|
||||
if (!is_cast_op) {
|
||||
const auto &signature = prim->signatures();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto obj = op_exec_info->op_inputs[i];
|
||||
auto sig = SignatureEnumRW::kRWDefault;
|
||||
if (signature.size() > 0) {
|
||||
if (sig_size > 0) {
|
||||
sig = signature[i].rw;
|
||||
}
|
||||
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
|
||||
|
|
|
@ -455,7 +455,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
>>> data.set_dtype(mindspore.int32)
|
||||
mindspore.int32
|
||||
)mydelimiter")
|
||||
.def("set_cast_dtype", &Tensor::set_cast_dtype)
|
||||
.def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr)
|
||||
.def("__str__", &Tensor::ToString)
|
||||
.def("__repr__", &Tensor::ToStringRepr)
|
||||
.def(py::pickle(
|
||||
|
|
|
@ -292,7 +292,6 @@ class _PynativeExecutor:
|
|||
def __init__(self):
|
||||
self._executor = PynativeExecutor_.get_instance()
|
||||
|
||||
#TODO(kpy):add a type arg
|
||||
def new_graph(self, obj, *args, **kwargs):
|
||||
self._executor.new_graph(obj, *args, *(kwargs.values()))
|
||||
|
||||
|
|
|
@ -269,7 +269,7 @@ class Tensor : public MetaTensor {
|
|||
|
||||
std::string id() const { return id_; }
|
||||
TypePtr cast_dtype() { return cast_dtype_; }
|
||||
void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; }
|
||||
void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; }
|
||||
|
||||
void SetNeedWait(bool need_wait) {
|
||||
if (event_ != nullptr) {
|
||||
|
|
|
@ -582,10 +582,13 @@ class Cell(Cell_):
|
|||
param (Parameter): The parameter to cast.
|
||||
"""
|
||||
if hasattr(self, "_mindspore_flags"):
|
||||
if self._mindspore_flags.get('fp16'):
|
||||
param.set_cast_dtype(mstype.float16)
|
||||
if self._mindspore_flags.get('fp32'):
|
||||
param.set_cast_dtype(mstype.float32)
|
||||
elif self._mindspore_flags.get('fp16'):
|
||||
param.set_cast_dtype(mstype.float16)
|
||||
else:
|
||||
# retest dtype
|
||||
param.set_cast_dtype()
|
||||
return param
|
||||
|
||||
def insert_child_to_cell(self, child_name, child):
|
||||
|
|
|
@ -464,7 +464,7 @@ raise_set = [
|
|||
'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
|
||||
'desc_inputs': [0]}),
|
||||
('AssignAdd_Error', {
|
||||
'block': (P.AssignAdd(), {'exception': IndexError}),
|
||||
'block': (P.AssignAdd(), {'exception': ValueError}),
|
||||
'desc_inputs': [[1]]}),
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue