From 1d65ae598a96c78c81856b99b54ce09da54d4d95 Mon Sep 17 00:00:00 2001 From: huanghui Date: Tue, 26 May 2020 17:32:33 +0800 Subject: [PATCH] extract const_to_attr_strided_slice_grad pass --- .../common/common_backend_optimization.cc | 2 + mindspore/ccsrc/pre_activate/common/helper.cc | 2 +- .../pass/const_input_to_attr_registry.cc | 1 - .../pass/const_to_attr_strided_slice_grad.cc | 132 ++++++++++++++++++ .../pass/const_to_attr_strided_slice_grad.h | 34 +++++ mindspore/ccsrc/utils/utils.h | 2 + .../const_to_attr_strided_slice_grad_test.cc | 77 ++++++++++ .../pass/convert_const_input_to_attr_test.cc | 39 ------ .../const_to_attr_strided_slice_grad.py | 46 ++++++ .../pre_activate/convert_const_input_test.py | 15 -- 10 files changed, 294 insertions(+), 56 deletions(-) create mode 100644 mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc create mode 100644 mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h create mode 100644 tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc create mode 100644 tests/ut/cpp/python_input/gtest_input/pre_activate/const_to_attr_strided_slice_grad.py diff --git a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc index 03833111222..7ba42a60a0d 100644 --- a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc @@ -21,6 +21,7 @@ #include "pre_activate/pass/convert_tuple_output_to_maketuple.h" #include "pre_activate/pass/convert_const_input_to_tensor_input.h" #include "pre_activate/pass/convert_tuple_input_to_dynamic_input.h" +#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" #include "utils/context/ms_context.h" #include "debug/anf_ir_dump.h" @@ -42,6 +43,7 @@ void BackendCommonOptimization(const std::shared_ptr &kern auto optimizer = std::make_shared(); auto common_pm = std::make_shared("common_pm"); common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc index 4cda390fbb2..9be537775e1 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ b/mindspore/ccsrc/pre_activate/common/helper.cc @@ -687,7 +687,7 @@ bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &va MS_EXCEPTION_IF_NULL(equiv1_node); auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); MS_EXCEPTION_IF_NULL(equiv2_node); - return equiv1_node == equiv2_node; + return *equiv1_node == *equiv2_node; } AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc index 08bbb351377..a8bd55f125d 100644 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc @@ -52,7 +52,6 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(kScatterNdOpName, {2}); Register(kStridedSliceAssignOpName, {1, 2, 3}); Register(kStridedSliceOpName, {1, 2, 3}); - Register(kStridedSliceGradOpName, {1, 2, 3, 4}); Register(kFlattenGradOpName, {1}); Register(kExpandDimsOpName, {1}); Register(kSplitOpName, {0}); diff --git a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc new file mode 100644 index 00000000000..68f54c1636f --- /dev/null +++ b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" +#include +#include +#include "session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "pipeline/static_analysis/abstract_value.h" +#include "pre_activate/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t strides_index = 5; + +bool GetStridesValues(const CNodePtr &strided_slice_grad, ValuePtrList *strides_values) { + MS_EXCEPTION_IF_NULL(strided_slice_grad); + if (strided_slice_grad->size() < 6) { + MS_LOG(DEBUG) << "Op strided_slice_grad's inputs size less than 6, graph not changed"; + return false; + } + auto strides_input = strided_slice_grad->input(strides_index); + MS_EXCEPTION_IF_NULL(strides_input); + auto strides_value_node = strides_input->cast(); + if (strides_value_node == nullptr) { + MS_LOG(DEBUG) << "strides is not a value node."; + return false; + } + auto value = strides_value_node->value(); + if (value == nullptr) { + MS_LOG(DEBUG) << "strides has no value."; + return false; + } + auto value_tuple = value->cast(); + if (value_tuple == nullptr) { + MS_LOG(DEBUG) << "strides is not a value tuple."; + return false; + } + *strides_values = value_tuple->value(); + return true; +} + +bool CheckValues(const ValuePtrList &strides_values) { + if (strides_values.empty()) { + MS_LOG(DEBUG) << "strides_values is empty"; + return false; + } + for (auto &value : strides_values) { + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + auto scalar = value->cast(); + MS_EXCEPTION_IF_NULL(scalar); + if (!scalar->isa()) { + MS_LOG(DEBUG) << "strides value is not a Integer"; + return false; + } + if (GetValue(scalar) != 1) { + MS_LOG(DEBUG) << "StridedSliceGrad has no 1 value"; + return false; + } + } else { + MS_LOG(DEBUG) << "The value " << value << "of tuple is not a scalar"; + return false; + } + } + return true; +} + +bool CheckAttrs(const CNodePtr &strided_slice_grad) { + MS_EXCEPTION_IF_NULL(strided_slice_grad); + if (!AnfAlgo::HasNodeAttr(kAttrNewAxisMask, strided_slice_grad) || + !AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, strided_slice_grad)) { + MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not exist in cnode[" + strided_slice_grad->DebugString() + "]"; + return false; + } + auto new_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrNewAxisMask); + auto shrink_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrShrinkAxisMask); + if (new_axis_mask != 0 || shrink_axis_mask != 0) { + MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not equal 0"; + return false; + } + return true; +} +} // namespace + +const BaseRef ConstToAttrStridedSliceGradPass::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto strided_slice_grad_prim = std::make_shared(kStridedSliceGradOpName); + return VectorRef({strided_slice_grad_prim, Xs}); +} + +const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto strided_slice_grad = node->cast(); + MS_EXCEPTION_IF_NULL(strided_slice_grad); + + if (!CheckAttrs(strided_slice_grad)) { + MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; + return nullptr; + } + + ValuePtrList strides_values; + if (!GetStridesValues(strided_slice_grad, &strides_values)) { + return nullptr; + } + + if (!CheckValues(strides_values)) { + MS_LOG(INFO) << "Check strides' values failed, graph not changed"; + return nullptr; + } + + ConstInputToAttr(strided_slice_grad, {1, 2, 3, 4}); + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h new file mode 100644 index 00000000000..2e364244bf7 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ + +#include +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConstToAttrStridedSliceGradPass : public PatternProcessPass { + public: + explicit ConstToAttrStridedSliceGradPass(bool multigraph = true) + : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {} + ~ConstToAttrStridedSliceGradPass() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index bc06d61a67a..976fa848d4c 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -193,6 +193,8 @@ constexpr auto kAttrIsTraining = "is_training"; constexpr auto kAttrFusionId = "fusion_id"; constexpr auto kAttrLabelIndex = "label_index"; constexpr auto kAttrLabelSwitchList = "label_switch_list"; +constexpr auto kAttrNewAxisMask = "new_axis_mask"; +constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc new file mode 100644 index 00000000000..8fc709433e0 --- /dev/null +++ b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/backend_common_test.h" +#include "operator/ops.h" +#include "debug/anf_ir_dump.h" +#include "common/py_func_graph_fetcher.h" +#include "session/anf_runtime_algorithm.h" +#include "pre_activate/common/optimizer.h" +#include "pre_activate/common/pass_manager.h" +#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" +#include "utils/utils.h" +#include "common/utils.h" + +namespace mindspore { +namespace opt { +class TestHWConstToAttrStridedSliceGrad : public BackendCommon { + public: + TestHWConstToAttrStridedSliceGrad() : getPyFun_("gtest_input.pre_activate.const_to_attr_strided_slice_grad", true) {} + ~TestHWConstToAttrStridedSliceGrad() override = default; + + public: + UT::PyFuncGraphFetcher getPyFun_; +}; + +TEST_F(TestHWConstToAttrStridedSliceGrad, test_strided_slice_grad) { + FuncGraphPtr g = getPyFun_.CallAndParseRet("test_const_to_attr_strided_slice_grad", "before"); + ASSERT_TRUE(g != nullptr); + FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_const_to_attr_strided_slice_grad", "after"); + ASSERT_TRUE(g_after != nullptr); + + auto ret = g->get_return(); + ASSERT_TRUE(ret != nullptr); + EXPECT_NE(ret->input(1), nullptr); + EXPECT_NE(ret->input(1)->cast(), nullptr); + auto cnode = ret->input(1)->cast(); + EXPECT_FALSE(AnfAlgo::HasNodeAttr("shapex", cnode)); + EXPECT_FALSE(AnfAlgo::HasNodeAttr("begin", cnode)); + EXPECT_FALSE(AnfAlgo::HasNodeAttr("end", cnode)); + EXPECT_FALSE(AnfAlgo::HasNodeAttr("strides", cnode)); + EXPECT_FALSE(CheckEqualGraph(g, g_after)); + + std::vector shp_x{16, 1, 1024}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + ASSERT_TRUE(kg != nullptr); + + ret = kg->get_return(); + ASSERT_TRUE(ret != nullptr); + EXPECT_NE(ret->input(1), nullptr); + EXPECT_NE(ret->input(1)->cast(), nullptr); + auto make_tuple = ret->input(1)->cast(); + ASSERT_TRUE(make_tuple != nullptr); + EXPECT_NE(make_tuple->input(1), nullptr); + EXPECT_NE(make_tuple->input(1)->cast(), nullptr); + cnode = make_tuple->input(1)->cast(); + EXPECT_TRUE(AnfAlgo::HasNodeAttr("shapex", cnode)); + EXPECT_TRUE(AnfAlgo::HasNodeAttr("begin", cnode)); + EXPECT_TRUE(AnfAlgo::HasNodeAttr("end", cnode)); + EXPECT_TRUE(AnfAlgo::HasNodeAttr("strides", cnode)); + EXPECT_TRUE(CheckEqualGraph(kg, g_after)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc index 99130efd5df..fcb3b19a249 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc @@ -109,44 +109,5 @@ TEST_F(TestHWConstInputToAttr, test_onehot) { EXPECT_TRUE(AnfAlgo::HasNodeAttr("depth", cnode)); EXPECT_TRUE(CheckEqualGraph(func_graph, g_after)); } - -TEST_F(TestHWConstInputToAttr, test_strided_slice_grad) { - FuncGraphPtr g = getPyFun_.CallAndParseRet("test_convert_strided_slice_grad_input_to_attr", "before"); - ASSERT_TRUE(g != nullptr); - FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_convert_strided_slice_grad_input_to_attr", "after"); - ASSERT_TRUE(g_after != nullptr); - - auto ret = g->get_return(); - ASSERT_TRUE(ret != nullptr); - EXPECT_NE(ret->input(1), nullptr); - EXPECT_NE(ret->input(1)->cast(), nullptr); - auto cnode = ret->input(1)->cast(); - EXPECT_FALSE(AnfAlgo::HasNodeAttr("shapex", cnode)); - EXPECT_FALSE(AnfAlgo::HasNodeAttr("begin", cnode)); - EXPECT_FALSE(AnfAlgo::HasNodeAttr("end", cnode)); - EXPECT_FALSE(AnfAlgo::HasNodeAttr("strides", cnode)); - EXPECT_FALSE(CheckEqualGraph(g, g_after)); - - std::vector shp_x{16, 1, 1024}; - auto x_abstract = std::make_shared(kFloat32, shp_x); - AbstractBasePtrList args_spec_list{x_abstract}; - auto func_graph = GetKernelGraph(g, args_spec_list); - ASSERT_TRUE(func_graph != nullptr); - - ret = func_graph->get_return(); - ASSERT_TRUE(ret != nullptr); - EXPECT_NE(ret->input(1), nullptr); - EXPECT_NE(ret->input(1)->cast(), nullptr); - auto make_tuple = ret->input(1)->cast(); - ASSERT_TRUE(make_tuple != nullptr); - EXPECT_NE(make_tuple->input(1), nullptr); - EXPECT_NE(make_tuple->input(1)->cast(), nullptr); - cnode = make_tuple->input(1)->cast(); - EXPECT_TRUE(AnfAlgo::HasNodeAttr("shapex", cnode)); - EXPECT_TRUE(AnfAlgo::HasNodeAttr("begin", cnode)); - EXPECT_TRUE(AnfAlgo::HasNodeAttr("end", cnode)); - EXPECT_TRUE(AnfAlgo::HasNodeAttr("strides", cnode)); - EXPECT_TRUE(CheckEqualGraph(func_graph, g_after)); -} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/const_to_attr_strided_slice_grad.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/const_to_attr_strided_slice_grad.py new file mode 100644 index 00000000000..9abd0f1d539 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/const_to_attr_strided_slice_grad.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore.ops import Primitive +from mindspore.ops.operations import _grad_ops as G + +stridedslicegrad = G.StridedSliceGrad() +backend_stridedslicegrad = Primitive('StridedSliceGrad') +make_tuple = Primitive('make_tuple') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_const_to_attr_strided_slice_grad(tag): + fns = FnDict() + + @fns + def before(x): + return stridedslicegrad(x, (16, 128, 1024), (0, 0, 0), (16, 1, 1024), (1, 1, 1)) + + @fns + def after(x): + res = backend_stridedslicegrad(x) + return make_tuple(res) + + return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_const_input_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_const_input_test.py index ef2925826b5..02cb5c44886 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_const_input_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_const_input_test.py @@ -110,21 +110,6 @@ def test_convert_onehot_input_to_attr(tag): return fns[tag] -def test_convert_strided_slice_grad_input_to_attr(tag): - fns = FnDict() - - @fns - def before(x): - return stridedslicegrad(x, (16, 128, 1024), (0, 0, 0), (16, 1, 1024), (1, 1, 1)) - - @fns - def after(x): - res = backend_stridedslicegrad(x) - return make_tuple(res) - - return fns[tag] - - def test_convert_onehot_input_to_tensor1(tag): fns = FnDict()