From e3d7304208f6616e36e8e7903360cdc267f06310 Mon Sep 17 00:00:00 2001 From: lianliguang Date: Tue, 18 May 2021 20:35:24 +0800 Subject: [PATCH] add function to remove monad abs when using cpp infer --- .../pipeline/jit/static_analysis/prim.cc | 4 +- mindspore/core/abstract/dshape.cc | 38 ++++++++++++++----- mindspore/core/abstract/infer_functions.h | 2 - mindspore/core/abstract/prim_others.cc | 14 ------- .../core/abstract/primitive_infer_map.cc | 1 - mindspore/core/ops/assign.cc | 20 +++++++++- mindspore/core/utils/check_convert_utils.cc | 15 ++++++++ mindspore/core/utils/check_convert_utils.h | 1 + mindspore/ops/operations/other_ops.py | 8 +--- tests/ut/cpp/abstract/dshape_test.cc | 3 ++ 10 files changed, 70 insertions(+), 36 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 5bd49859b95..758594a1fb6 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -613,7 +613,9 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); bool need_infer_value = - !(eval_impl_.in_white_list_) || (context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode); + (!(eval_impl_.in_white_list_) || (context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode)) && + std::all_of(args.begin(), args.end(), + [](const AbstractBasePtr &abs) -> bool { return (abs->BuildValue() != nullptr); }); AbstractBasePtr abs_base = nullptr; ValuePtr value = nullptr; prim_->BeginRecordAddAttr(); diff --git a/mindspore/core/abstract/dshape.cc b/mindspore/core/abstract/dshape.cc index 6695fa7f0ff..a9d41c52126 100644 --- a/mindspore/core/abstract/dshape.cc +++ b/mindspore/core/abstract/dshape.cc @@ -25,6 +25,23 @@ namespace mindspore { namespace abstract { +namespace { +std::string ShapeVectorToStr(const std::vector &shp) { + std::ostringstream buffer; + bool f_begin = true; + buffer << "("; + for (auto &x : shp) { + if (!f_begin) { + buffer << ", "; + } else { + f_begin = false; + } + buffer << x; + } + buffer << ")"; + return buffer.str(); +} +} // namespace // used for print BaseShape content std::ostream &operator<<(std::ostream &os, const BaseShape &bs) { os << bs.ToString(); @@ -48,17 +65,18 @@ bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == oth std::string Shape::ToString() const { std::ostringstream buffer; - bool f_begin = true; - buffer << "("; - for (auto &x : shape_) { - if (!f_begin) { - buffer << ", "; - } else { - f_begin = false; - } - buffer << x; + bool has_dyn_shape = IsDynamic(); + if (has_dyn_shape) { + buffer << "{ shape : "; + } + buffer << ShapeVectorToStr(shape_); + if (has_dyn_shape) { + buffer << " | min shape: "; + buffer << ShapeVectorToStr(min_shape_); + buffer << " | max shape: "; + buffer << ShapeVectorToStr(max_shape_); + buffer << " }"; } - buffer << ")"; return buffer.str(); } diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 70d4ebf0d04..ffba50327ad 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -286,8 +286,6 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); template diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index f3588ff7fde..32b1ebb0f27 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -562,19 +562,5 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri } return args_spec_list[0]->Broaden(); } - -AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: Ref, value, [universal] - CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); - - MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; - auto type = args_spec_list[0]->BuildType(); - if (type->type_id() == kObjectTypeRefKey) { - return args_spec_list[1]->Broaden(); - } else { - return args_spec_list[0]; - } -} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index ec7f3ca71dc..78c58bc72fd 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -136,7 +136,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { // Others {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}}, {prim::kPrimLoad, {InferImplLoad, nullptr, true}}, - {prim::kPrimAssign, {InferImplAssign, nullptr, true}}, // Set impl to null as it will use PartialEvaluator; {prim::kPrimPartial, {nullptr, nullptr, true}}, {prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}}, diff --git a/mindspore/core/ops/assign.cc b/mindspore/core/ops/assign.cc index 1bfad53e91b..a828ff1bcd0 100644 --- a/mindspore/core/ops/assign.cc +++ b/mindspore/core/ops/assign.cc @@ -26,6 +26,24 @@ namespace mindspore { namespace ops { -REGISTER_PRIMITIVE_C(kNameAssign, Assign); +AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + (void)CheckAndConvertUtils::CheckInteger("Assign infer", (CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), + kEqual, 2, prim_name); + auto check_types = common_valid_types; + check_types.emplace(kBool); + auto variable_type = args_spec_list[0]->BuildType(); + auto value_type = args_spec_list[1]->BuildType(); + CheckAndConvertUtils::CheckScalarOrTensorTypesSame(std::map{{"value", value_type}}, check_types, + prim_name); + if (variable_type->isa()) { + return args_spec_list[1]->Broaden(); + } + CheckAndConvertUtils::CheckTensorTypeValid("variable", variable_type, check_types, prim_name); + return args_spec_list[0]; +} +REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index c206233d8a2..224bbf11328 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -643,4 +643,19 @@ int64_t CheckAndConvertUtils::GetAndCheckFormat(const ValuePtr &value) { } return data_format; } +int64_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list) { + int64_t remove_monad_count = abs_list.size(); + for (const auto &item : abs_list) { + if (item->isa()) { + --remove_monad_count; + } + } + + for (int64_t i = 0; i < remove_monad_count; ++i) { + if (abs_list[i]->isa()) { + MS_EXCEPTION(UnknownError) << "The monad inputs of the node must at last of the node inputs."; + } + } + return remove_monad_count; +} } // namespace mindspore diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index 72c89b93d14..c81dbe75b5f 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -323,6 +323,7 @@ class CheckAndConvertUtils { const std::string &arg_name); static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape); static int64_t GetAndCheckFormat(const ValuePtr &value); + static int64_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list); private: static bool IsEqualVector(const std::vector &vec_1, const std::vector &vec_2); diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 14d6df3dd79..2a2e83cd142 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -23,7 +23,7 @@ from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register -class Assign(PrimitiveWithCheck): +class Assign(Primitive): """ Assigns `Parameter` with a value. @@ -73,12 +73,6 @@ class Assign(PrimitiveWithCheck): self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) self.add_prim_attr('side_effect_mem', True) - def check_dtype(self, variable, value): - types = mstype.number_type + (mstype.bool_,) - if variable != mstype.type_refkey: - validator.check_tensor_dtype_valid("variable", variable, types, self.name) - validator.check_scalar_or_tensor_types_same({"value": value}, types, self.name) - class InplaceAssign(PrimitiveWithInfer): """ diff --git a/tests/ut/cpp/abstract/dshape_test.cc b/tests/ut/cpp/abstract/dshape_test.cc index da0e9ed3eef..a6f92b9cecb 100644 --- a/tests/ut/cpp/abstract/dshape_test.cc +++ b/tests/ut/cpp/abstract/dshape_test.cc @@ -29,6 +29,7 @@ class TestDShape : public UT::Common { Shape shp_2; Shape shp_3; Shape shp_4; + Shape shp_5; NoShape shp_noshp_1; NoShape shp_noshp_2; @@ -42,6 +43,7 @@ class TestDShape : public UT::Common { shp_2({1, 1}), shp_3({1, 2}), shp_4({1}), + shp_5({-1, 2}, {1, 2}, {3, 3}), shp_noshp_1(), shp_noshp_2(), @@ -67,6 +69,7 @@ TEST_F(TestDShape, ToString) { ASSERT_EQ(shp_3.ToString(), "(1, 2)"); ASSERT_EQ(shp_noshp_1.ToString(), "NoShape"); ASSERT_EQ(shp_tuple_2.ToString(), "TupleShape(NoShape, (1, 1, 1))"); + ASSERT_EQ(shp_5.ToString(), "{ shape : (-1, 2) | min shape: (1, 2) | max shape: (3, 3) }"); } TEST_F(TestDShape, Clone) {