forked from mindspore-Ecosystem/mindspore
[Less BN]Eliminating BN according to pattern matching.
This commit is contained in:
parent
782cac9119
commit
a89a0d4810
|
@ -31,6 +31,7 @@
|
|||
#include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/merge_addn.h"
|
||||
#include "frontend/optimizer/irpass/accumulaten_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/less_batch_normalization.h"
|
||||
#include "frontend/optimizer/irpass/minmax_grad.h"
|
||||
#include "frontend/optimizer/irpass/param_replace.h"
|
||||
#include "frontend/optimizer/irpass/partial_eliminate.h"
|
||||
|
@ -143,6 +144,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
accumulaten_eliminater_ =
|
||||
MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);
|
||||
|
||||
// Accelerated Algorithm
|
||||
less_batch_normalization_ =
|
||||
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization", prim::kPrimAdd);
|
||||
|
||||
// inline
|
||||
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
||||
inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
|
||||
|
|
|
@ -84,6 +84,9 @@ class OptimizeIRPassLib {
|
|||
// AccumulateNV2
|
||||
SubstitutionPtr accumulaten_eliminater_;
|
||||
|
||||
// Accelerated Algorithm
|
||||
SubstitutionPtr less_batch_normalization_;
|
||||
|
||||
// Gradient irpasses
|
||||
SubstitutionPtr minmaximum_grad_;
|
||||
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/irpass/less_batch_normalization.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
namespace {
|
||||
const char kLessBatchNormalizationPassName[] = "less_bn";
|
||||
constexpr auto kValidResidualStructureIndex = 1;
|
||||
constexpr auto kBNParametersStartIndex = 2;
|
||||
// Pattern 1
|
||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
||||
// ↘ BatchNorm -> Conv2D -> -> -> -> ↗
|
||||
constexpr auto kFirstBranchPattern1 = 12;
|
||||
constexpr auto kSecondBranchPattern1 = 3;
|
||||
constexpr auto kFirstBranchStartIndexPattern1 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern1 = 11;
|
||||
const std::vector<kStructureTuple> ResidualStructureBasePattern{
|
||||
{kFirstBranchPattern1,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||
{kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
|
||||
{kSecondBranchPattern1, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
|
||||
// Pattern 2
|
||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
||||
// ↘ -> -> ... ... ... -> -> ↗
|
||||
constexpr auto kFirstBranchPattern2 = 12;
|
||||
constexpr auto kSecondBranchPattern2 = 1;
|
||||
constexpr auto kFirstBranchStartIndexPattern2 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern2 = 11;
|
||||
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
|
||||
{kFirstBranchPattern2,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||
{kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
|
||||
{kSecondBranchPattern2, {prim::kPrimRelu}, {SIZE_MAX, SIZE_MAX}}};
|
||||
// Pattern 3
|
||||
// Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
|
||||
// ↘ BatchNorm -> Conv2D -> -> ... ... ... -> -> ↗
|
||||
constexpr auto kFirstBranchPattern3 = 11;
|
||||
constexpr auto kSecondBranchPattern3 = 3;
|
||||
constexpr auto kFirstBranchStartIndexPattern3 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern3 = 10;
|
||||
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
|
||||
{kFirstBranchPattern3,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
||||
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
|
||||
prim::kPrimConv2D},
|
||||
{kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
|
||||
{kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
|
||||
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
|
||||
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
|
||||
|
||||
bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> ¶meter_list) {
|
||||
if (a == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return std::any_of(parameter_list.begin(), parameter_list.end(), [&a](const AnfNodePtr &b) {
|
||||
return (b->isa<Parameter>() && a->name() == b->cast<ParameterPtr>()->name());
|
||||
});
|
||||
}
|
||||
|
||||
bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr ¶meter) {
|
||||
auto param_output = manager->node_users().find(parameter);
|
||||
if (param_output == manager->node_users().end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
|
||||
const std::vector<AnfNodePtr> &remove_parameter_list) {
|
||||
auto roots = manager->roots();
|
||||
if (roots.size() != 1) {
|
||||
MS_LOG(ERROR) << "The size of roots " << roots.size() << " is not valid.";
|
||||
return;
|
||||
}
|
||||
auto root_graph = *(roots.begin());
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
|
||||
std::vector<AnfNodePtr> real_remove_parameter_list;
|
||||
std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
|
||||
std::back_inserter(real_remove_parameter_list),
|
||||
[&manager](const AnfNodePtr ¶m) { return IsRealRemoveParameterNode(manager, param); });
|
||||
|
||||
auto root_parameters = root_graph->parameters();
|
||||
root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
|
||||
[&real_remove_parameter_list](const AnfNodePtr &node) {
|
||||
return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
|
||||
}),
|
||||
root_parameters.end());
|
||||
|
||||
manager->SetParameters(root_graph, root_parameters);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
|
||||
const kStructureTuple &patternTuple) {
|
||||
if (index < 0) {
|
||||
return false;
|
||||
}
|
||||
const auto &use_pattern = std::get<1>(patternTuple);
|
||||
int32_t use_index = index % use_pattern.size();
|
||||
if (!IsPrimitiveCNode(cnode, use_pattern[use_index])) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
|
||||
const std::vector<kStructureTuple> &match_pattern) {
|
||||
if ((match_branch_ + 1 >= total_match_node_.size()) || (match_branch_ >= match_pattern.size())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t index = static_cast<int32_t>(match_node_) - static_cast<int32_t>(total_match_node_[match_branch_]);
|
||||
const auto &pattern = match_pattern[match_branch_];
|
||||
if (!MatchStructureNode(cnode, index, pattern)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
match_node_++;
|
||||
if (match_node_ == total_match_node_.back()) {
|
||||
is_match_ = true;
|
||||
return false;
|
||||
}
|
||||
if (match_node_ == total_match_node_[match_branch_ + 1]) {
|
||||
match_branch_++;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
||||
return;
|
||||
}
|
||||
if (match_pattern.empty()) {
|
||||
return;
|
||||
}
|
||||
const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
|
||||
if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
|
||||
remove_node_list_.insert(cnode);
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
const auto &fg = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if (!fg->has_attr(kLessBatchNormalizationPassName)) {
|
||||
return nullptr;
|
||||
}
|
||||
match_pattern_ = 0;
|
||||
while (match_pattern_ < kNeedMatchPattern.size()) {
|
||||
Reset();
|
||||
const auto ¤t_pattern = kNeedMatchPattern.at(match_pattern_);
|
||||
size_t sum_match_node = 0;
|
||||
std::for_each(current_pattern.begin(), current_pattern.end(), [&](const kStructureTuple &t) {
|
||||
sum_match_node += std::get<0>(t);
|
||||
total_match_node_.emplace_back(sum_match_node);
|
||||
});
|
||||
AnfVisitor::Match(prim::kPrimAdd, {IsCNode, IsCNode})(node);
|
||||
if (is_match_) {
|
||||
break;
|
||||
}
|
||||
match_pattern_++;
|
||||
}
|
||||
|
||||
if (!is_match_ || remove_node_list_.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto manager = optimizer->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> remove_load_list;
|
||||
std::vector<AnfNodePtr> remove_parameter_list;
|
||||
for (auto &iter : remove_node_list_) {
|
||||
// Need to remove batchnorm's parameter input.
|
||||
if (IsPrimitiveCNode(iter, prim::kPrimBatchNorm)) {
|
||||
std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(),
|
||||
std::back_inserter(remove_load_list),
|
||||
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
|
||||
std::transform(
|
||||
remove_load_list.begin(), remove_load_list.end(), std::back_inserter(remove_parameter_list),
|
||||
[](const AnfNodePtr &node) { return node->cast<CNodePtr>()->input(kValidResidualStructureIndex); });
|
||||
}
|
||||
// Remove useless node.
|
||||
auto input_cnode = iter->input(kValidResidualStructureIndex);
|
||||
manager->Replace(iter, input_cnode);
|
||||
}
|
||||
RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
void LessBatchNormalization::Visit(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto ¤t_pattern = kNeedMatchPattern.at(match_pattern_);
|
||||
IsRemoveNode(cnode, current_pattern);
|
||||
if (!MatchGraphStructure(cnode, current_pattern)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto search_input = cnode->input(kValidResidualStructureIndex);
|
||||
if (search_input != nullptr && search_input->isa<CNode>()) {
|
||||
this->Visit(search_input->cast<CNodePtr>());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void LessBatchNormalization::Reset() {
|
||||
remove_node_list_.clear();
|
||||
total_match_node_.clear();
|
||||
total_match_node_.emplace_back(0);
|
||||
match_node_ = 0;
|
||||
match_branch_ = 0;
|
||||
is_match_ = false;
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
using kStructureTuple = std::tuple<size_t, std::vector<PrimitivePtr>, std::pair<size_t, size_t>>;
|
||||
class LessBatchNormalization : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||
void Visit(const CNodePtr &cnode) override;
|
||||
void Reset();
|
||||
void IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
|
||||
bool MatchStructureNode(const CNodePtr &cnode, const int32_t index, const kStructureTuple &patternTuple);
|
||||
bool MatchGraphStructure(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
|
||||
|
||||
private:
|
||||
std::unordered_set<CNodePtr> remove_node_list_{};
|
||||
std::vector<size_t> total_match_node_{0};
|
||||
size_t match_node_{0};
|
||||
size_t match_branch_{0};
|
||||
size_t match_pattern_{0};
|
||||
bool is_match_{false};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
|
|
@ -167,6 +167,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.mini_step_allgather_replace_,
|
||||
},
|
||||
false, true);
|
||||
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({
|
||||
irpass.less_batch_normalization_,
|
||||
});
|
||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||
opt::irpass::ResolveIRPassLib resolve_irpass;
|
||||
|
||||
|
@ -177,6 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
||||
OptPassGroupMap map_a({{"a_1", a_1},
|
||||
{"a_2", a_2},
|
||||
{"accelerated_algorithm", accelerated_algorithm},
|
||||
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
|
||||
{"parallel", opt::OptPassConfig(parallel::StepParallel)},
|
||||
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Accelerating.
|
||||
|
||||
Provide auto accelerating for network, such as Less BN.
|
||||
"""
|
||||
from .less_batch_normalization import LessBN
|
||||
|
||||
__all__ = ['LessBN']
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""less batch normalization"""
|
||||
from ..cell import Cell
|
||||
|
||||
class LessBN(Cell):
|
||||
"""
|
||||
Reduce the number of BN automatically to improve the network performance
|
||||
and ensure the network accuracy.
|
||||
|
||||
Args:
|
||||
network (Cell): Network to be modified.
|
||||
|
||||
Examples:
|
||||
>>> network = acc.LessBN(network)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(LessBN, self).__init__()
|
||||
self.network = network
|
||||
self.network.set_acc("less_bn")
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.network(*inputs)
|
|
@ -1036,6 +1036,29 @@ class Cell(Cell_):
|
|||
self.add_flags_recursive(**flags)
|
||||
return self
|
||||
|
||||
def set_acc(self, acc_type):
|
||||
"""
|
||||
In order to improve the network performance, configure the network auto enable to
|
||||
accelerate the algorithm in the algorithm library.
|
||||
|
||||
If `acc_type is not in the algorithm library`, Please view the algorithm in the algorithm library
|
||||
through `algorithm library`.
|
||||
|
||||
Note:
|
||||
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
|
||||
|
||||
Args:
|
||||
acc_type (:str:`less_bn`): accelerate algorithm.
|
||||
|
||||
Raises:
|
||||
ValueError: If acc_type is not in the algorithm library.
|
||||
"""
|
||||
if acc_type not in ("less_bn",):
|
||||
raise ValueError("acc_type is not in the algorithm library.")
|
||||
flags = {"less_bn": acc_type == "less_bn"}
|
||||
self.add_flags_recursive(**flags)
|
||||
return self
|
||||
|
||||
def set_grad(self, requires_grad=True):
|
||||
"""
|
||||
Sets the cell flag for gradient.
|
||||
|
|
Loading…
Reference in New Issue