forked from mindspore-Ecosystem/mindspore
fix convert dynamic broadcast to pass
This commit is contained in:
parent
1bfe24bd98
commit
5b91ce9b64
|
@ -23,10 +23,11 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
const auto kV = "V";
|
||||
const auto kA = "A";
|
||||
const auto kVs = "Vs";
|
||||
const auto kMBroadcastTo = "m_broadcast_to";
|
||||
const auto kRBroadcastTo = "r_broadcast_to";
|
||||
AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) {
|
||||
AnfNodePtr BuildBroadcastTo(const PatternMap &m, const AnfNodePtr &) {
|
||||
auto node = m.Get(kMBroadcastTo);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto broadcast_to_op_name = prim::kPrimBroadcastTo->name();
|
||||
|
@ -37,6 +38,8 @@ AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) {
|
|||
CNodePtr broadcast_to_node =
|
||||
opt::NewCNode({NewValueNode(std::make_shared<Primitive>(broadcast_to_op_name)), input_x}, func_graph, {node});
|
||||
MS_EXCEPTION_IF_NULL(broadcast_to_node);
|
||||
MS_EXCEPTION_IF_NULL(node->abstract());
|
||||
MS_EXCEPTION_IF_NULL(node->abstract()->BuildShape());
|
||||
broadcast_to_node->set_abstract(node->abstract());
|
||||
auto shape_ptr = node->abstract()->BuildShape()->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
|
@ -55,11 +58,11 @@ bool ConvertDynamicBroadcastTo::CheckMatchedDAG(const PatternMap &, const FuncGr
|
|||
}
|
||||
|
||||
void ConvertDynamicBroadcastTo::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddVar(kV).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV});
|
||||
(*src_pattern).AddVar(kA).AddSeqVar(kVs).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kA, kVs});
|
||||
}
|
||||
|
||||
void ConvertDynamicBroadcastTo::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV}, BuildDynamicBroadcastTo);
|
||||
(*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimBroadcastTo, kA}, BuildBroadcastTo);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/backend_common_test.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pass_manager.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/common/pass/convert_dynamic_broadcast_to.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
class TestConvertDynmicBroadcastToPass : public BackendCommon {
|
||||
public:
|
||||
TestConvertDynmicBroadcastToPass() : getPyFun_("gtest_input.pre_activate.convert_dynamic_broadcast_to_test", true) {}
|
||||
~TestConvertDynmicBroadcastToPass() override = default;
|
||||
|
||||
public:
|
||||
UT::PyFuncGraphFetcher getPyFun_;
|
||||
};
|
||||
|
||||
/// Feature: ConvertDynmicBroadcastTo Pass
|
||||
/// Description: ConvertDynmicBroadcastTo rewrite graph
|
||||
/// Expectation: Get correct Graph
|
||||
TEST_F(TestConvertDynmicBroadcastToPass, TestConvert) {
|
||||
// build func graph
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_dyn_broadcast", "before");
|
||||
std::vector<int64_t> shpx{3};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shpx);
|
||||
std::vector<int64_t> shpy{2, 3};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shpy);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
|
||||
auto fg = GetFuncGraph(g, args_spec_list);
|
||||
|
||||
bool has_dyn = false;
|
||||
for (const auto &n : TopoSort(fg->get_return())) {
|
||||
if (IsPrimitiveCNode(n, prim::kPrimDynamicBroadcastTo)) {
|
||||
has_dyn = true;
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(has_dyn);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<ConvertDynamicBroadcastTo>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
for (const auto &n : TopoSort(fg->get_return())) {
|
||||
ASSERT_FALSE(IsPrimitiveCNode(n, prim::kPrimDynamicBroadcastTo));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.context as context
|
||||
import mindspore.ops.operations as ops
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fn_dict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fn_dict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fn_dict.get(name)
|
||||
|
||||
|
||||
def test_dyn_broadcast(tag):
|
||||
"""
|
||||
Feature: ConvertDynmicBroadcastTo Pass
|
||||
Description: ConvertDynmicBroadcastTo rewrite graph.
|
||||
Expectation: Get correct Graph.
|
||||
"""
|
||||
fns = FnDict()
|
||||
d_shape = ops.TensorShape()
|
||||
d_broadcastto = inner.DynamicBroadcastTo()
|
||||
|
||||
@fns
|
||||
def before(data, shape):
|
||||
shape = d_shape(shape)
|
||||
return d_broadcastto(data, shape)
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in New Issue