From 144a35b17e076f9d5aa06c449dfb92be6b4cff41 Mon Sep 17 00:00:00 2001 From: fary86 Date: Thu, 20 Aug 2020 15:30:57 +0800 Subject: [PATCH] Adapt GatherV2 for dynamic shape --- .../pipeline/jit/static_analysis/prim.cc | 175 ++++++++++++------ .../ccsrc/pipeline/jit/static_analysis/prim.h | 2 + .../jit/static_analysis/static_analysis.cc | 12 +- mindspore/ccsrc/pipeline/jit/validator.cc | 4 + mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 21 ++- mindspore/ccsrc/pybind_api/ir/primitive_py.h | 2 + mindspore/core/abstract/infer_functions.h | 11 ++ mindspore/core/abstract/prim_arrays.cc | 57 ++++++ mindspore/core/abstract/prim_maths.cc | 9 + mindspore/core/abstract/prim_nn.cc | 20 ++ .../core/abstract/primitive_infer_map.cc | 6 + mindspore/core/base/core_ops.h | 5 + mindspore/core/ir/primitive.h | 3 +- mindspore/core/utils/flags.cc | 15 ++ mindspore/core/utils/flags.h | 10 + mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/array_ops.py | 42 ++++- mindspore/ops/operations/math_ops.py | 10 +- mindspore/ops/operations/nn_ops.py | 18 +- mindspore/ops/primitive.py | 92 +++++++++ tests/ut/python/ir/test_row_tensor.py | 25 ++- tests/ut/python/ops/test_dynamic_shape.py | 109 +++++++++++ 22 files changed, 549 insertions(+), 102 deletions(-) create mode 100755 tests/ut/python/ops/test_dynamic_shape.py diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index f66ab771011..252c4c09d77 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -49,22 +49,6 @@ using mindspore::parse::PyObjectWrapper; std::unordered_set prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem", "env_getitem"}; -EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { - if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { - auto ret_abstract = AbstractEval(args); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; - return ret_abstract; - } - } - prim_->BeginRecordAddAttr(); - AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); - prim_->EndRecordAddAttr(); - auto added_attrs = prim_->evaluate_added_attrs(); - auto infer_result = std::make_shared(abs_base, std::make_shared(added_attrs)); - return infer_result; -} - EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { AbstractBasePtrList args_spec_list; @@ -289,45 +273,45 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { py::dict dic; if (abs_base->isa()) { auto arg_tensor = dyn_cast(abs_base); - dic["shape"] = arg_tensor->shape()->shape(); + dic[ATTR_SHAPE] = arg_tensor->shape()->shape(); if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { const auto &min_shape = arg_tensor->shape()->min_shape(); const auto &max_shape = arg_tensor->shape()->max_shape(); if (!min_shape.empty() && !max_shape.empty()) { - dic["min_shape"] = min_shape; - dic["max_shape"] = max_shape; + dic[ATTR_MIN_SHAPE] = min_shape; + dic[ATTR_MAX_SHAPE] = max_shape; } } - dic["dtype"] = arg_tensor->BuildType(); - dic["value"] = BuildValue(arg_tensor->BuildValue()); + dic[ATTR_DTYPE] = arg_tensor->BuildType(); + dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue()); } else if (abs_base->isa()) { auto arg = dyn_cast(abs_base); - dic["shape"] = 
arg->shape()->shape(); - dic["dtype"] = arg->BuildType(); - dic["value"] = BuildValue(arg->BuildValue()); + dic[ATTR_SHAPE] = arg->shape()->shape(); + dic[ATTR_DTYPE] = arg->BuildType(); + dic[ATTR_VALUE] = BuildValue(arg->BuildValue()); } else if (abs_base->isa()) { auto arg = dyn_cast(abs_base); - dic["shape"] = arg->shape()->shape(); - dic["dtype"] = arg->BuildType(); - dic["value"] = BuildValue(arg->BuildValue()); + dic[ATTR_SHAPE] = arg->shape()->shape(); + dic[ATTR_DTYPE] = arg->BuildType(); + dic[ATTR_VALUE] = BuildValue(arg->BuildValue()); } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { ShapeVector shape; - dic["shape"] = shape; - dic["dtype"] = abs_base->BuildType(); - dic["value"] = BuildValue(abs_base->BuildValue()); + dic[ATTR_SHAPE] = shape; + dic[ATTR_DTYPE] = abs_base->BuildType(); + dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue()); } else if (abs_base->isa()) { auto arg_slice = dyn_cast(abs_base); ShapeVector shape; - dic["shape"] = shape; - dic["dtype"] = arg_slice->BuildType(); - dic["value"] = BuildValue(arg_slice->BuildValue()); + dic[ATTR_SHAPE] = shape; + dic[ATTR_DTYPE] = arg_slice->BuildType(); + dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue()); } else if (abs_base->isa()) { auto value = abs_base->cast()->ref(); dic = ConvertAbstractToPython(value); } else if (abs_base->isa()) { - dic["shape"] = py::none(); - dic["dtype"] = py::ellipsis(); - dic["value"] = py::ellipsis(); + dic[ATTR_SHAPE] = py::none(); + dic[ATTR_DTYPE] = py::ellipsis(); + dic[ATTR_VALUE] = py::ellipsis(); } else if (abs_base->isa()) { auto arg_tuple = dyn_cast(abs_base); size_t len = arg_tuple->size(); @@ -336,12 +320,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { for (size_t i = 0; i < len; i++) { py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]); - shape_tuple[i] = out["shape"]; - dtype_tuple[i] = out["dtype"]; + shape_tuple[i] = out[ATTR_SHAPE]; + dtype_tuple[i] = out[ATTR_DTYPE]; } - dic["shape"] = shape_tuple; - dic["dtype"] = dtype_tuple; - dic["value"] = BuildValue(arg_tuple->BuildValue()); + dic[ATTR_SHAPE] = shape_tuple; + dic[ATTR_DTYPE] = dtype_tuple; + dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue()); } else if (abs_base->isa()) { auto arg_list = dyn_cast(abs_base); size_t len = arg_list->size(); @@ -350,25 +334,25 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { for (size_t i = 0; i < len; i++) { py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); - shape_list[i] = out["shape"]; - dtype_list[i] = out["dtype"]; + shape_list[i] = out[ATTR_SHAPE]; + dtype_list[i] = out[ATTR_DTYPE]; } - dic["shape"] = shape_list; - dic["dtype"] = dtype_list; - dic["value"] = BuildValue(arg_list->BuildValue()); + dic[ATTR_SHAPE] = shape_list; + dic[ATTR_DTYPE] = dtype_list; + dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue()); } else if (abs_base->isa()) { - dic["shape"] = py::none(); - dic["dtype"] = py::none(); - dic["value"] = py::none(); + dic[ATTR_SHAPE] = py::none(); + dic[ATTR_DTYPE] = py::none(); + dic[ATTR_VALUE] = py::none(); } else if (abs_base->isa()) { - dic["shape"] = py::none(); - dic["dtype"] = abs_base->BuildType(); - dic["value"] = py::none(); + dic[ATTR_SHAPE] = py::none(); + dic[ATTR_DTYPE] = abs_base->BuildType(); + dic[ATTR_VALUE] = py::none(); } else if (abs_base->isa()) { auto arg = dyn_cast(abs_base); - dic["shape"] = py::none(); - dic["dtype"] = arg->BuildType(); - dic["value"] = py::none(); + dic[ATTR_SHAPE] = py::none(); + dic[ATTR_DTYPE] = arg->BuildType(); + 
dic[ATTR_VALUE] = py::none(); } else { auto value = abs_base->BuildValue(); if ((*value == *kAnyValue)) { @@ -409,18 +393,20 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) { // Convert to AbstractValue based on type and shape - auto out_dtype = output["dtype"]; - if (output["value"].is_none()) { - auto out_shape = output["shape"]; - py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none(); - py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none(); + auto out_dtype = output[ATTR_DTYPE]; + if (output[ATTR_VALUE].is_none()) { + auto out_shape = output[ATTR_SHAPE]; + py::object min_shape = + output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none(); + py::object max_shape = + output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none(); return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape); } // Convert pyobject to Value, then to AbstractValue ValuePtr converted_ret = nullptr; TypePtr dtype = py::isinstance(out_dtype) ? out_dtype.cast() : nullptr; - bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype); + bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype); if (!converted) { MS_LOG(EXCEPTION) << "Convert data failed"; } @@ -447,6 +433,73 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic } } // end anonymous namespace +EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + auto prim_py = dyn_cast(prim_); + if (prim_py == nullptr) { + MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; + } + + // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' + MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); + auto py_args = PreparePyInputs(prim_py, args); + prim_py->RunCheck(py_args); + + prim_->BeginRecordAddAttr(); + AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); + + if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) { + return std::make_shared(abs_base, std::make_shared(added_attrs)); + } + + // Call method 'infer_value' for primitive with this method for constant propagation + py::tuple py_vals(py_args.size()); + for (size_t i = 0; i < py_args.size(); ++i) { + py_vals[i] = py_args[i][ATTR_VALUE]; + } + py::object py_ret = prim_py->RunInferValue(py_vals); + if (py::isinstance(py_ret)) { + return std::make_shared(abs_base, std::make_shared(added_attrs)); + } + + // Convert pyobject to Value, then to AbstractValue + ValuePtr converted_ret = nullptr; + TypePtr dtype = abs_base->BuildType(); + bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype); + if (!converted) { + MS_LOG(EXCEPTION) << "Convert data failed"; + } + auto res_spec = FromValue(converted_ret); + MS_EXCEPTION_IF_NULL(res_spec); + if (res_spec->isa()) { + // Replace to tensor constant node in specialize + auto res_tensor = res_spec->cast(); + res_tensor->set_value(converted_ret); + } + return std::make_shared(res_spec, std::make_shared(added_attrs)); +} + +EvalResultPtr StandardPrimEvaluator::EvalPrim(const 
AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { + auto ret_abstract = AbstractEval(args); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; + return ret_abstract; + } + } + + if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) { + return EvalPyCheckPrim(engine, args); + } + + prim_->BeginRecordAddAttr(); + AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); + return std::make_shared(abs_base, std::make_shared(added_attrs)); +} + EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { auto ret_abstract = AbstractEval(args); if (ret_abstract != nullptr) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 8d5aff93052..477d9ba861d 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -42,6 +42,8 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator { std::string ToString() const override { return identifier_ + prim_->name(); } private: + EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args); + PrimitivePtr prim_; const StandardPrimitiveEvalImpl eval_impl_; }; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 78d6a563c70..d5f81cac100 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -308,20 +308,18 @@ void AnalysisEngine::Clear() { namespace { EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { // Custom Primitive with python infer_shape, infer_type - EvaluatorPtr evaluator = nullptr; MS_EXCEPTION_IF_NULL(prim); if (prim->isa()) { - evaluator = std::make_shared(prim); - return evaluator; + return std::make_shared(prim); } if (prim->isa()) { - evaluator = std::make_shared(prim); - return evaluator; + return std::make_shared(prim); } if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) { - evaluator = std::make_shared(prim); - return evaluator; + return std::make_shared(prim); } + + EvaluatorPtr evaluator = nullptr; if (prim->HasPyEvaluator()) { auto prim_py = dyn_cast(prim); if (prim_py != nullptr) { diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 7cf8071ec3a..6046666add0 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -55,6 +55,10 @@ void ValidateOperation(const AnfNodePtr &node) { MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; return; } + if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) { + MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method."; + return; + } if (prim->name() == "fake_bprop") { MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue(prim->GetAttr("info")); } diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index f26093ef031..d4b236d23a9 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc 
@@ -254,16 +254,33 @@ py::dict PrimitivePy::RunInfer(const py::tuple &args) { if (!HasPyObj()) { MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty"; } - auto infer_fuc = python_obj_.attr("__infer__"); + auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER); return infer_fuc(*args); } +void PrimitivePy::RunCheck(const py::tuple &args) { + if (!HasPyObj()) { + MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty"; + } + auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK); + (void)check_func(*args); +} + +py::object PrimitivePy::RunInferValue(const py::tuple &args) { + if (!HasPyObj()) { + MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty"; + } + auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE); + return infer_value(*args); +} + REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { (void)py::enum_(*m, "prim_type", py::arithmetic()) .value("unknown", PrimType::kPrimTypeUnknown) .value("builtin", PrimType::kPrimTypeBuiltIn) .value("py_infer_shape", PrimType::kPrimTypePyInferShape) - .value("user_custom", PrimType::kPrimTypeUserCustom); + .value("user_custom", PrimType::kPrimTypeUserCustom) + .value("py_infer_check", PrimType::kPrimTypePyInferCheck); (void)py::class_>(*m, "Primitive_") .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) .def(py::init()) diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h index 761db9142e4..3a19f0f43dc 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h @@ -62,6 +62,8 @@ class PrimitivePy : public Primitive { const bool parse_info_ = true; const py::object &GetPyObj() const { return python_obj_; } py::dict RunInfer(const py::tuple &args); + void RunCheck(const py::tuple &args); + py::object RunInferValue(const py::tuple &args); bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); } bool HasPyObj() { return python_obj_.operator bool(); } PrimitivePtr Clone() override; diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index d73526d42c0..19bcd3fa8e6 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -81,6 +81,9 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -176,6 +179,14 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr 
&primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 136a7b5075f..fb7684eb980 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -14,6 +14,8 @@ * limitations under the License. */ +#include +#include #include "abstract/infer_functions.h" #include "abstract/utils.h" #include "abstract/param_validator.h" @@ -226,5 +228,60 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt // outputs: dx return std::make_shared(ids->element(), ids_idx->shape()); } + +AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string &op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + AbstractTensorPtr params = CheckArg(op_name, args_spec_list, 0); + AbstractTensorPtr indices = CheckArg(op_name, args_spec_list, 1); + AbstractScalarPtr axis = CheckArg(op_name, args_spec_list, 2); + + auto params_shp = params->shape()->shape(); + auto indices_shp = indices->shape()->shape(); + auto axis_val = GetValue(axis->BuildValue()); + + auto params_rank = static_cast(params_shp.size()); + if (axis_val < 0) { + axis_val += params_rank; + } + + auto calc_shape = [axis_val, ¶ms_shp](const ShapeVector &inp_vec) -> ShapeVector { + ShapeVector out_vec; + std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec)); + copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec)); + copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec)); + return out_vec; + }; + + ShapeVector out_shape = calc_shape(indices_shp); + if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) { + ShapeVector min_shape = calc_shape(indices->shape()->min_shape()); + ShapeVector max_shape = calc_shape(indices->shape()->max_shape()); + return std::make_shared(params->element(), + std::make_shared(out_shape, min_shape, max_shape)); + } + + return std::make_shared(params->element(), std::make_shared(out_shape)); +} + +AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string &op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTensorPtr input = CheckArg(op_name, args_spec_list, 0); + auto shape = input->shape()->shape(); + + bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int dim) { return dim == Shape::SHP_ANY; }); + std::vector tensor_shp({static_cast(shape.size())}); + if (has_dyn_shape) { + auto elem = std::make_shared(std::make_shared(), std::make_shared(32)); + return std::make_shared(elem, std::make_shared(tensor_shp)); + } + auto shp_buf_size = sizeof(int) * shape.size(); + auto tensor = std::make_shared(kNumberTypeInt32, tensor_shp, shape.data(), shp_buf_size); + + return tensor->ToAbstract(); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index f0da9535b01..1a3d8f6c712 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ 
b/mindspore/core/abstract/prim_maths.cc @@ -37,5 +37,14 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive return std::make_shared(AbstractBasePtrList({dx, dy})); } + +AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three tensors. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto inp = CheckArg(op_name, args_spec_list, 0); + return inp->Clone()->Broaden(); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index 625c07fb904..ad8691f087a 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -445,5 +445,25 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti return std::make_shared(std::make_shared(kAnyValue, kUInt8), std::make_shared(std::vector{shape_y})); } + +AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + CheckArgsSize(primitive->name(), args_spec_list, 5); + AbstractBasePtrList elements; + for (size_t i = 0; i < 3; ++i) { + elements.push_back(args_spec_list[i]->Clone()->Broaden()); + } + return std::make_shared(elements); +} + +AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + CheckArgsSize(primitive->name(), args_spec_list, 7); + AbstractBasePtrList elements; + for (size_t i = 0; i < 2; ++i) { + elements.push_back(args_spec_list[i]->Clone()->Broaden()); + } + return std::make_shared(elements); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 6bf6d7f3473..dfeda677123 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -37,6 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { // Maths {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, + {prim::kPrimSqrt, {InferImplSqrt, true}}, // Array {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, @@ -44,6 +45,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimPack, {InferImplPack, true}}, {prim::kPrimUnique, {InferImplUnique, true}}, {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, + {prim::kPrimGatherV2, {InferImplGatherV2, true}}, + {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, + {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, {prim::kPrimMakeList, {InferImplMakeList, true}}, @@ -77,6 +81,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, + {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, + {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, // Others {prim::kPrimIdentity, {InferImplIdentity, true}}, // Set impl to null as it will use PartialEvaluator; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 438a9e4ea87..9491ec39a00 
100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -84,6 +84,9 @@ inline const PrimitivePtr kPrimConcat = std::make_shared("Concat"); inline const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); inline const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); inline const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); +inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared("SparseGatherV2"); +inline const PrimitivePtr kPrimShape = std::make_shared("Shape"); +inline const PrimitivePtr kPrimDynamicShape = std::make_shared("DynamicShape"); inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); inline const PrimitivePtr kPrimSize = std::make_shared("Size"); @@ -154,6 +157,8 @@ inline const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); +inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared("SparseApplyFtrl"); +inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared("SparseApplyProximalAdagrad"); // Comm ops inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index c6525dabd67..9ad5c3adf74 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -35,7 +35,8 @@ enum PrimType { kPrimTypeBuiltIn, // Built-in primitive operator kPrimTypePyInferShape, // Primitive operator defined by custom kPrimTypePyInferTensor, // Primitive operator defined by custom - kPrimTypeUserCustom + kPrimTypeUserCustom, + kPrimTypePyInferCheck // Primitive operator with input args checking method }; class Primitive : public Named { diff --git a/mindspore/core/utils/flags.cc b/mindspore/core/utils/flags.cc index a36d0367d66..ba4324883eb 100644 --- a/mindspore/core/utils/flags.cc +++ b/mindspore/core/utils/flags.cc @@ -23,4 +23,19 @@ const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; + +// method names of python primitive called from c++ source code +// 1. infer method name of class 'PrimitiveWithInfer' +const char PY_PRIM_METHOD_INFER[] = "__infer__"; +// 2. check method name of class 'PrimitiveWithCheck' +const char PY_PRIM_METHOD_CHECK[] = "__check__"; +// 3. 
method name of class 'PrimitivePy' for constant propagation +const char PY_PRIM_METHOD_INFER_VALUE[] = "infer_value"; + +// type inference related attributes +const char ATTR_VALUE[] = "value"; +const char ATTR_DTYPE[] = "dtype"; +const char ATTR_SHAPE[] = "shape"; +const char ATTR_MIN_SHAPE[] = "min_shape"; +const char ATTR_MAX_SHAPE[] = "max_shape"; } // namespace mindspore diff --git a/mindspore/core/utils/flags.h b/mindspore/core/utils/flags.h index 89268fbaed9..0ab5d195c9d 100644 --- a/mindspore/core/utils/flags.h +++ b/mindspore/core/utils/flags.h @@ -23,6 +23,16 @@ extern const char GRAPH_FLAG_HAS_EFFECT[]; extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; extern const char GRAPH_FLAG_RANDOM_EFFECT[]; extern const char GRAPH_FLAG_SIDE_EFFECT[]; + +extern const char PY_PRIM_METHOD_INFER[]; +extern const char PY_PRIM_METHOD_CHECK[]; +extern const char PY_PRIM_METHOD_INFER_VALUE[]; + +extern const char ATTR_VALUE[]; +extern const char ATTR_DTYPE[]; +extern const char ATTR_SHAPE[]; +extern const char ATTR_MIN_SHAPE[]; +extern const char ATTR_MAX_SHAPE[]; } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_FLAGS_H diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index aa8ba107d30..7a6c749b6b5 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, - Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding, + Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, @@ -206,6 +206,7 @@ __all__ = [ 'HookBackward', 'InvertPermutation', 'Shape', + 'DynamicShape', 'DropoutDoMask', 'DropoutGenMask', 'DropoutGrad', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index c8d21e4dd25..8f0a9ba4270 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -27,7 +27,7 @@ import numpy as np from .._utils import get_concat_offset from ..operations.math_ops import _infer_shape_reduce -from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op +from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op from ..._c_expression import signature_dtype as sig_dtype from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_rw as sig_rw @@ -142,6 +142,11 @@ class ExpandDims(PrimitiveWithInfer): out = {'shape': x_shape, 'dtype': x['dtype'], 'value': value} + if 'min_shape' in x and 'max_shape' in x: + out['min_shape'] = x['min_shape'] + out['min_shape'].insert(axis_v, 1) + out['max_shape'] = x['max_shape'] + out['max_shape'].insert(axis_v, 1) return out @@ -277,6 +282,9 @@ class Cast(PrimitiveWithInfer): out = {'shape': x['shape'], 'dtype': mstype.tensor_type(t['value']), 'value': value} + if 'min_shape' in x and 'max_shape' in x: + out['min_shape'] = x['min_shape'] + out['max_shape'] = x['max_shape'] return out @@ -445,6 +453,27 @@ class Shape(PrimitiveWithInfer): return out +class DynamicShape(Primitive): 
+ """ + Returns the shape of input tensor. + + Inputs: + - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + + Outputs: + Tensor[int], 1-dim Tensor of type int32 + + Examples: + >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32) + >>> shape = P.DynamicShape() + >>> output = shape(input_tensor) + """ + + @prim_attr_register + def __init__(self): + """init Shape""" + + class Squeeze(PrimitiveWithInfer): """ Returns a tensor with the same type but dimensions of 1 being removed based on axis. @@ -578,7 +607,7 @@ class Unique(Primitive): self.init_prim_io_names(inputs=['x'], outputs=['output']) -class GatherV2(PrimitiveWithInfer): +class GatherV2(PrimitiveWithCheck): """ Returns a slice of input tensor based on the specified indices and axis. @@ -605,7 +634,7 @@ class GatherV2(PrimitiveWithInfer): """init index_select""" self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) - def __infer__(self, params, indices, axis): + def __check__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) @@ -613,13 +642,6 @@ class GatherV2(PrimitiveWithInfer): params_shp = params['shape'] rank = len(params_shp) validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) - if axis_v < 0: - axis_v += rank - out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] - out = {'shape': out_shape, - 'dtype': params['dtype'], - 'value': None} - return out class SparseGatherV2(GatherV2): diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 0890758f445..ad2cc1c6fc8 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -26,7 +26,7 @@ from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor from .._utils import get_broadcast_shape -from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op +from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op def _infer_shape_reduce(x, axis, keep_dims, prim_name): @@ -1257,7 +1257,7 @@ class Rsqrt(PrimitiveWithInfer): return None -class Sqrt(PrimitiveWithInfer): +class Sqrt(PrimitiveWithCheck): """ Returns square root of a tensor element-wise. 
@@ -1279,12 +1279,8 @@ class Sqrt(PrimitiveWithInfer): """init Sqrt""" self.init_prim_io_names(inputs=['x'], outputs=['output']) - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_type): + def check_dtype(self, x_type): validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) - return x_type def infer_value(self, x): if x is not None: diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index d6de700bf61..2c94af7847f 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -28,7 +28,7 @@ from ..._c_expression import signature_dtype as sig_dtype from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype -from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register +from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register from ..operations.math_ops import _infer_shape_reduce @@ -4354,7 +4354,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): return var_dtype, accum_dtype -class SparseApplyProximalAdagrad(PrimitiveWithInfer): +class SparseApplyProximalAdagrad(PrimitiveWithCheck): r""" Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad, an additional index tensor is input. @@ -4433,11 +4433,10 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): outputs=['var', 'accum']) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) - def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): + def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) - return var_shape, accum_shape - def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): + def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name) @@ -4446,7 +4445,6 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): valid_types = [mstype.int16, mstype.int32, mstype.int64, mstype.uint16, mstype.uint32, mstype.uint64] validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) - return var_dtype, accum_dtype class ApplyAddSign(PrimitiveWithInfer): @@ -4978,7 +4976,7 @@ class ApplyFtrl(PrimitiveWithInfer): return var_type -class SparseApplyFtrl(PrimitiveWithInfer): +class SparseApplyFtrl(PrimitiveWithCheck): """ Update relevant entries according to the FTRL-proximal scheme. 
@@ -5053,21 +5051,19 @@ class SparseApplyFtrl(PrimitiveWithInfer):
         self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
+    def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
         if len(var_shape) > 1:
             validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
         validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape, accum_shape, linear_shape
-    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
+    def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
         args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
                 "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
         validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
         validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
-        return var_dtype, accum_dtype, linear_dtype
 class SparseApplyFtrlV2(PrimitiveWithInfer):
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 7b2596a885f..963c9762ba3 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -200,6 +200,84 @@ class Primitive(Primitive_):
         return self._update_parameter
+
+class PrimitiveWithCheck(Primitive):
+    """
+    PrimitiveWithCheck is the base class of primitives defined in Python that only check their operator input
+    arguments, while using the infer method registered in the C++ source code.
+
+    Three methods can be overridden to define the check logic of the primitive: __check__(), check_shape() and
+    check_dtype(). If __check__() is defined in the primitive, it has the highest priority and is called first.
+    If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
+    the shape and type.
+
+    Args:
+        name (str): Name of the current Primitive.
+
+    Examples:
+        >>> # init a Primitive class with check
+        >>> class Flatten(PrimitiveWithCheck):
+        >>>     @prim_attr_register
+        >>>     def __init__(self):
+        >>>         pass
+        >>>     def check_shape(self, input_x):
+        >>>         validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
+        >>>
+        >>>     def check_dtype(self, input_x):
+        >>>         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
+        >>>
+        >>> # init a Primitive obj
+        >>> add = Flatten()
+    """
+
+    def __init__(self, name):
+        Primitive.__init__(self, name)
+        self.set_prim_type(prim_type.py_infer_check)
+
+    def _clone(self):
+        """
+        Deeply clones the primitive object.
+
+        Calls the __init__() method with the same arguments. This method is called in parser if the
+        flag self.__setattr_flag__ is True.
+        """
+        cloned_prim = Primitive._clone(self)
+        return cloned_prim
+
+    def check_shape(self, *args):
+        """
+        Check shapes of input args.
+
+        Note:
+            The shape of a scalar is an empty tuple.
+
+        Args:
+            args (tuple(int)): shapes of input tensors.
+
+        Return:
+            None.
+        """
+        return None
+
+    def check_dtype(self, *args):
+        """
+        Check data types of input args.
+ + Args: + args (:class:`mindspore.dtype`): data type of inputs. + + Return: + None. + """ + return None + + def __check__(self, *args): + """Check shape, type, and value at the same time by using dictionary as arguments.""" + tracks = ['dtype', 'shape'] + for track in tracks: + fn = getattr(self, 'check_' + track) + fn(*(x[track] for x in args)) + + class PrimitiveWithInfer(Primitive): """ PrimitiveWithInfer is the base class of primitives in python defines functions for tracking inference in python. @@ -306,6 +384,18 @@ class PrimitiveWithInfer(Primitive): if not is_graph_mode: return out + # output does not contain dynamic shape, no need to calculate min/max shape + def has_dynamic_shape(shp): + if isinstance(shp, int): + return shp < 0 + if isinstance(shp, (list, tuple)): + return any(has_dynamic_shape(e) for e in shp) + return False + + if not has_dynamic_shape(out['shape']): + return out + + # calculate min/max shape for output def get_specified_shape(elems, attr): has_specified_shape = False ret_vals = [] @@ -345,6 +435,8 @@ def prim_attr_register(fn): def deco(self, *args, **kwargs): if isinstance(self, PrimitiveWithInfer): PrimitiveWithInfer.__init__(self, self.__class__.__name__) + elif isinstance(self, PrimitiveWithCheck): + PrimitiveWithCheck.__init__(self, self.__class__.__name__) else: Primitive.__init__(self, self.__class__.__name__) bound_args = inspect.signature(fn).bind(self, *args, **kwargs) diff --git a/tests/ut/python/ir/test_row_tensor.py b/tests/ut/python/ir/test_row_tensor.py index 62d7d761a1c..efc27302f05 100644 --- a/tests/ut/python/ir/test_row_tensor.py +++ b/tests/ut/python/ir/test_row_tensor.py @@ -27,7 +27,7 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like -from mindspore.ops.primitive import constexpr +from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register from mindspore.ops._grad.grad_base import bprop_getters from mindspore import Tensor, RowTensor, context from mindspore.common.parameter import Parameter, ParameterTuple @@ -105,10 +105,31 @@ def _generate_inverse_index(x_shape, axis): perm = index[1:1 + axis] + (0,) + index[1 + axis:] return perm -class MySparseGatherV2(P.GatherV2): +# pylint: disable=W0231 +class MySparseGatherV2(PrimitiveWithInfer): """ For test """ + @prim_attr_register + def __init__(self): + """init index_select""" + self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) + + def __infer__(self, params, indices, axis): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) + axis_v = axis['value'] + params_shp = params['shape'] + rank = len(params_shp) + validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) + if axis_v < 0: + axis_v += rank + out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] + out = {'shape': out_shape, + 'dtype': params['dtype'], + 'value': None} + return out @bprop_getters.register(MySparseGatherV2) def get_bprop_sparse_gather_v2(self): diff --git a/tests/ut/python/ops/test_dynamic_shape.py b/tests/ut/python/ops/test_dynamic_shape.py new file mode 100755 index 00000000000..12dabb618d5 --- /dev/null +++ b/tests/ut/python/ops/test_dynamic_shape.py @@ -0,0 +1,109 @@ 
+# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test dynamic shape """ +from mindspore import Tensor, context, nn, Parameter +from mindspore.ops import operations as P +from mindspore import dtype as mstype + +import numpy as np + +context.set_context(mode=context.GRAPH_MODE, save_graphs=False) + + +def test_sparse_apply_proximal_ada_grad(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() + self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var") + self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum") + self.lr = 0.01 + self.l1 = 0.0 + self.l2 = 0.0 + def construct(self, grad, indices): + out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices) + return out[0] + + class NetWrapper(nn.Cell): + def __init__(self): + super(NetWrapper, self).__init__() + self.unq = P.Unique() + self.add = P.TensorAdd() + self.expand_dims = P.ExpandDims() + self.cast = P.Cast() + self.net = Net() + + def construct(self, grad, inp): + ids, _ = self.unq(inp) + new_grad = self.expand_dims(ids, 1) + new_grad = self.cast(new_grad, mstype.float32) + grad + return self.net(new_grad, ids) + + net = NetWrapper() + grad = Tensor(np.random.rand(1, 80).astype(np.float32)) + indices = Tensor(np.ones([7800]), mstype.int32) + net(grad, indices) + + +def test_sparse_apply_ftrl(): + class SparseApplyFtrlNet(nn.Cell): + def __init__(self): + super(SparseApplyFtrlNet, self).__init__() + self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5) + self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var") + self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum") + self.linear = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="linear") + + def construct(self, grad, indices): + out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) + return out[0] + + class NetWrapper(nn.Cell): + def __init__(self): + super(NetWrapper, self).__init__() + self.unq = P.Unique() + self.add = P.TensorAdd() + self.expand_dims = P.ExpandDims() + self.cast = P.Cast() + self.net = SparseApplyFtrlNet() + + def construct(self, grad, inp): + ids, _ = self.unq(inp) + new_grad = self.expand_dims(ids, 1) + new_grad = self.cast(new_grad, mstype.float32) + grad + return self.net(new_grad, ids) + + net = NetWrapper() + grad = Tensor(np.random.rand(1, 80).astype(np.float32)) + indices = Tensor(np.ones([7800]), mstype.int32) + net(grad, indices) + + +def test_gatherv2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.unq = P.Unique() + self.gather = P.GatherV2() + + def construct(self, x, y): + u, _ = self.unq(y) + z = self.gather(x, u, 0) + return z + + x = Tensor(np.ones([20, 
12], dtype=np.float32)) + y = Tensor(np.ones([8], dtype=np.int32)) + net = Net() + net(x, y)
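For reference, the output-shape rule that the new InferImplGatherV2 applies is the usual GatherV2 splice, applied in the same way to the real shape and to the min_shape/max_shape bounds. Below is a minimal Python sketch, not part of the patch; the helper name is made up for illustration, and -1 stands in for an unknown dimension (Shape::SHP_ANY in the C++ code).

```python
# Sketch of the shape rule used by InferImplGatherV2: normalize a negative
# axis, then splice the indices shape into the params shape at that axis.
# The same rule is reused for min_shape and max_shape when the indices
# tensor is dynamically shaped (e.g. produced by Unique).
def gather_v2_out_shape(params_shape, indices_shape, axis):
    if axis < 0:
        axis += len(params_shape)
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

# Mirrors the test_gatherv2 case above: params [20, 12], indices from Unique(y).
assert gather_v2_out_shape([20, 12], [-1], 0) == [-1, 12]  # -1: unknown dim from Unique
assert gather_v2_out_shape([20, 12], [1], 0) == [1, 12]    # min_shape bound
assert gather_v2_out_shape([20, 12], [8], 0) == [8, 12]    # max_shape bound
```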