From 5b91ce9b64bbe15b7164944e471384f2d4d0febe Mon Sep 17 00:00:00 2001 From: reku1997 Date: Mon, 6 Mar 2023 15:10:46 +0800 Subject: [PATCH] fix convert dynamic broadcast to pass --- .../pass/convert_dynamic_broadcast_to.cc | 11 +-- .../convert_dynamic_broadcast_to_test.cc | 71 +++++++++++++++++++ .../convert_dynamic_broadcast_to_test.py | 48 +++++++++++++ 3 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 tests/ut/cpp/pre_activate/ascend/mindir/convert_dynamic_broadcast_to_test.cc create mode 100644 tests/ut/cpp/python_input/gtest_input/pre_activate/convert_dynamic_broadcast_to_test.py diff --git a/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc b/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc index 4cae1f463dd..50f60a0c347 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc +++ b/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc @@ -23,10 +23,11 @@ namespace mindspore { namespace opt { namespace { -const auto kV = "V"; +const auto kA = "A"; +const auto kVs = "Vs"; const auto kMBroadcastTo = "m_broadcast_to"; const auto kRBroadcastTo = "r_broadcast_to"; -AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) { +AnfNodePtr BuildBroadcastTo(const PatternMap &m, const AnfNodePtr &) { auto node = m.Get(kMBroadcastTo); MS_EXCEPTION_IF_NULL(node); auto broadcast_to_op_name = prim::kPrimBroadcastTo->name(); @@ -37,6 +38,8 @@ AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) { CNodePtr broadcast_to_node = opt::NewCNode({NewValueNode(std::make_shared(broadcast_to_op_name)), input_x}, func_graph, {node}); MS_EXCEPTION_IF_NULL(broadcast_to_node); + MS_EXCEPTION_IF_NULL(node->abstract()); + MS_EXCEPTION_IF_NULL(node->abstract()->BuildShape()); broadcast_to_node->set_abstract(node->abstract()); auto shape_ptr = node->abstract()->BuildShape()->cast(); MS_EXCEPTION_IF_NULL(shape_ptr); @@ -55,11 +58,11 @@ bool ConvertDynamicBroadcastTo::CheckMatchedDAG(const PatternMap &, const FuncGr } void ConvertDynamicBroadcastTo::DefineSrcPattern(SrcPattern *src_pattern) { - (*src_pattern).AddVar(kV).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV}); + (*src_pattern).AddVar(kA).AddSeqVar(kVs).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kA, kVs}); } void ConvertDynamicBroadcastTo::DefineDstPattern(DstPattern *dst_pattern) { - (*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV}, BuildDynamicBroadcastTo); + (*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimBroadcastTo, kA}, BuildBroadcastTo); } } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/mindir/convert_dynamic_broadcast_to_test.cc b/tests/ut/cpp/pre_activate/ascend/mindir/convert_dynamic_broadcast_to_test.cc new file mode 100644 index 00000000000..1b083b67ecb --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/mindir/convert_dynamic_broadcast_to_test.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/backend_common_test.h" +#include "frontend/operator/ops.h" +#include "include/common/debug/anf_ir_dump.h" +#include "common/py_func_graph_fetcher.h" +#include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pass_manager.h" +#include "utils/ms_utils.h" +#include "backend/common/pass/convert_dynamic_broadcast_to.h" + +namespace mindspore { +namespace opt { +namespace { +class TestConvertDynmicBroadcastToPass : public BackendCommon { + public: + TestConvertDynmicBroadcastToPass() : getPyFun_("gtest_input.pre_activate.convert_dynamic_broadcast_to_test", true) {} + ~TestConvertDynmicBroadcastToPass() override = default; + + public: + UT::PyFuncGraphFetcher getPyFun_; +}; + +/// Feature: ConvertDynmicBroadcastTo Pass +/// Description: ConvertDynmicBroadcastTo rewrite graph +/// Expectation: Get correct Graph +TEST_F(TestConvertDynmicBroadcastToPass, TestConvert) { + // build func graph + FuncGraphPtr g = getPyFun_.CallAndParseRet("test_dyn_broadcast", "before"); + std::vector shpx{3}; + auto x_abstract = std::make_shared(kFloat32, shpx); + std::vector shpy{2, 3}; + auto y_abstract = std::make_shared(kInt64, shpy); + AbstractBasePtrList args_spec_list{x_abstract, y_abstract}; + auto fg = GetFuncGraph(g, args_spec_list); + + bool has_dyn = false; + for (const auto &n : TopoSort(fg->get_return())) { + if (IsPrimitiveCNode(n, prim::kPrimDynamicBroadcastTo)) { + has_dyn = true; + } + } + ASSERT_TRUE(has_dyn); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + for (const auto &n : TopoSort(fg->get_return())) { + ASSERT_FALSE(IsPrimitiveCNode(n, prim::kPrimDynamicBroadcastTo)); + } +} +} // namespace +} // namespace opt +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_dynamic_broadcast_to_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_dynamic_broadcast_to_test.py new file mode 100644 index 00000000000..59bfe59f334 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_dynamic_broadcast_to_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.context as context +import mindspore.ops.operations as ops +from mindspore.ops.operations import _inner_ops as inner + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class FnDict: + def __init__(self): + self.fn_dict = {} + + def __call__(self, fn): + self.fn_dict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fn_dict.get(name) + + +def test_dyn_broadcast(tag): + """ + Feature: ConvertDynmicBroadcastTo Pass + Description: ConvertDynmicBroadcastTo rewrite graph. + Expectation: Get correct Graph. + """ + fns = FnDict() + d_shape = ops.TensorShape() + d_broadcastto = inner.DynamicBroadcastTo() + + @fns + def before(data, shape): + shape = d_shape(shape) + return d_broadcastto(data, shape) + + return fns[tag]