fix convert dynamic broadcast to pass

This commit is contained in:
reku1997 2023-03-06 15:10:46 +08:00
parent 1bfe24bd98
commit 5b91ce9b64
3 changed files with 126 additions and 4 deletions

View File

@ -23,10 +23,11 @@
namespace mindspore {
namespace opt {
namespace {
const auto kV = "V";
const auto kA = "A";
const auto kVs = "Vs";
const auto kMBroadcastTo = "m_broadcast_to";
const auto kRBroadcastTo = "r_broadcast_to";
AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) {
AnfNodePtr BuildBroadcastTo(const PatternMap &m, const AnfNodePtr &) {
auto node = m.Get(kMBroadcastTo);
MS_EXCEPTION_IF_NULL(node);
auto broadcast_to_op_name = prim::kPrimBroadcastTo->name();
@ -37,6 +38,8 @@ AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) {
CNodePtr broadcast_to_node =
opt::NewCNode({NewValueNode(std::make_shared<Primitive>(broadcast_to_op_name)), input_x}, func_graph, {node});
MS_EXCEPTION_IF_NULL(broadcast_to_node);
MS_EXCEPTION_IF_NULL(node->abstract());
MS_EXCEPTION_IF_NULL(node->abstract()->BuildShape());
broadcast_to_node->set_abstract(node->abstract());
auto shape_ptr = node->abstract()->BuildShape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_ptr);
@ -55,11 +58,11 @@ bool ConvertDynamicBroadcastTo::CheckMatchedDAG(const PatternMap &, const FuncGr
}
void ConvertDynamicBroadcastTo::DefineSrcPattern(SrcPattern *src_pattern) {
(*src_pattern).AddVar(kV).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV});
(*src_pattern).AddVar(kA).AddSeqVar(kVs).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kA, kVs});
}
void ConvertDynamicBroadcastTo::DefineDstPattern(DstPattern *dst_pattern) {
(*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV}, BuildDynamicBroadcastTo);
(*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimBroadcastTo, kA}, BuildBroadcastTo);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,71 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "frontend/operator/ops.h"
#include "include/common/debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/pass_manager.h"
#include "utils/ms_utils.h"
#include "backend/common/pass/convert_dynamic_broadcast_to.h"
namespace mindspore {
namespace opt {
namespace {
class TestConvertDynmicBroadcastToPass : public BackendCommon {
public:
TestConvertDynmicBroadcastToPass() : getPyFun_("gtest_input.pre_activate.convert_dynamic_broadcast_to_test", true) {}
~TestConvertDynmicBroadcastToPass() override = default;
public:
UT::PyFuncGraphFetcher getPyFun_;
};
/// Feature: ConvertDynmicBroadcastTo Pass
/// Description: ConvertDynmicBroadcastTo rewrite graph
/// Expectation: Get correct Graph
TEST_F(TestConvertDynmicBroadcastToPass, TestConvert) {
// build func graph
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_dyn_broadcast", "before");
std::vector<int64_t> shpx{3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shpx);
std::vector<int64_t> shpy{2, 3};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shpy);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
auto fg = GetFuncGraph(g, args_spec_list);
bool has_dyn = false;
for (const auto &n : TopoSort(fg->get_return())) {
if (IsPrimitiveCNode(n, prim::kPrimDynamicBroadcastTo)) {
has_dyn = true;
}
}
ASSERT_TRUE(has_dyn);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<ConvertDynamicBroadcastTo>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
for (const auto &n : TopoSort(fg->get_return())) {
ASSERT_FALSE(IsPrimitiveCNode(n, prim::kPrimDynamicBroadcastTo));
}
}
} // namespace
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,48 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.ops.operations as ops
from mindspore.ops.operations import _inner_ops as inner
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class FnDict:
def __init__(self):
self.fn_dict = {}
def __call__(self, fn):
self.fn_dict[fn.__name__] = fn
def __getitem__(self, name):
return self.fn_dict.get(name)
def test_dyn_broadcast(tag):
"""
Feature: ConvertDynmicBroadcastTo Pass
Description: ConvertDynmicBroadcastTo rewrite graph.
Expectation: Get correct Graph.
"""
fns = FnDict()
d_shape = ops.TensorShape()
d_broadcastto = inner.DynamicBroadcastTo()
@fns
def before(data, shape):
shape = d_shape(shape)
return d_broadcastto(data, shape)
return fns[tag]