diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index c6eb673624..50684c5d54 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -288,6 +288,7 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) { } bool GetSignatureType(const PrimitivePyPtr &prim, std::vector *dtypes) { + MS_EXCEPTION_IF_NULL(dtypes); auto signature = prim->signatures(); bool has_sig_dtype = false; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes), @@ -733,20 +734,29 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, abstract::AbstractBasePtrList *args_spec_list) { + MS_EXCEPTION_IF_NULL(op_masks); + MS_EXCEPTION_IF_NULL(args_spec_list); CNodePtr cnode = nullptr; std::vector inputs; + auto prim = op_exec_info->py_primitive; + const auto &signature = prim->signatures(); + inputs.push_back(NewValueNode(prim)); size_t size = op_exec_info->op_inputs.size(); + auto sig_size = signature.size(); // ignore signature for cast op + if (sig_size > 0 && sig_size != size) { + MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires " + << "inputs size " << sig_size; + } bool is_cast_op = (op_exec_info->op_name == "Cast"); if (!is_cast_op) { - const auto &signature = prim->signatures(); for (size_t i = 0; i < size; i++) { auto obj = op_exec_info->op_inputs[i]; auto sig = SignatureEnumRW::kRWDefault; - if (signature.size() > 0) { + if (sig_size > 0) { sig = signature[i].rw; } MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " " diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index f9636aacb9..7db7e22162 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -455,7 +455,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { >>> data.set_dtype(mindspore.int32) mindspore.int32 )mydelimiter") - .def("set_cast_dtype", &Tensor::set_cast_dtype) + .def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr) .def("__str__", &Tensor::ToString) .def("__repr__", &Tensor::ToStringRepr) .def(py::pickle( diff --git a/mindspore/common/api.py b/mindspore/common/api.py index e099ed18d5..279a93dc8a 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -292,7 +292,6 @@ class _PynativeExecutor: def __init__(self): self._executor = PynativeExecutor_.get_instance() - #TODO(kpy):add a type arg def new_graph(self, obj, *args, **kwargs): self._executor.new_graph(obj, *args, *(kwargs.values())) diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 8689980268..94c0329296 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -269,7 +269,7 @@ class Tensor : public MetaTensor { std::string id() const { return id_; } TypePtr cast_dtype() { return cast_dtype_; } - void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; } + void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; } void SetNeedWait(bool need_wait) { if (event_ != nullptr) { diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index e509815587..ee1832ab94 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -582,10 +582,13 @@ class Cell(Cell_): param (Parameter): The parameter to cast. """ if hasattr(self, "_mindspore_flags"): - if self._mindspore_flags.get('fp16'): - param.set_cast_dtype(mstype.float16) if self._mindspore_flags.get('fp32'): param.set_cast_dtype(mstype.float32) + elif self._mindspore_flags.get('fp16'): + param.set_cast_dtype(mstype.float16) + else: + # retest dtype + param.set_cast_dtype() return param def insert_child_to_cell(self, child_name, child): diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index 0c84bd1ef5..db93c3a5ec 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -464,7 +464,7 @@ raise_set = [ 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}), 'desc_inputs': [0]}), ('AssignAdd_Error', { - 'block': (P.AssignAdd(), {'exception': IndexError}), + 'block': (P.AssignAdd(), {'exception': ValueError}), 'desc_inputs': [[1]]}), ]