diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index db5b26de5ad..a8fd343e146 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -31,6 +31,7 @@ #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h" #include "frontend/optimizer/irpass/merge_addn.h" #include "frontend/optimizer/irpass/accumulaten_eliminate.h" +#include "frontend/optimizer/irpass/less_batch_normalization.h" #include "frontend/optimizer/irpass/minmax_grad.h" #include "frontend/optimizer/irpass/param_replace.h" #include "frontend/optimizer/irpass/partial_eliminate.h" @@ -143,6 +144,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { accumulaten_eliminater_ = MakeSubstitution(std::make_shared(), "accumulaten_eliminater", prim::kPrimAccumulateNV2); + // Accelerated Algorithm + less_batch_normalization_ = + MakeSubstitution(std::make_shared(), "less_batch_normalization", prim::kPrimAdd); + // inline inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); inline_without_move_ = MakeSubstitution(std::make_shared(false), "inline", IsCNodeGraph); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 7be50cb9eb3..9241e470f8e 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -84,6 +84,9 @@ class OptimizeIRPassLib { // AccumulateNV2 SubstitutionPtr accumulaten_eliminater_; + // Accelerated Algorithm + SubstitutionPtr less_batch_normalization_; + // Gradient irpasses SubstitutionPtr minmaximum_grad_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc new file mode 100644 index 00000000000..f2e26a61771 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., 
Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/optimizer/irpass/less_batch_normalization.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+namespace {
+// NOTE(review): template arguments and `&name` references were lost when this patch was extracted; they are
+// restored below from the way each value is used - confirm against the upstream commit.
+
+// Func graph attribute that switches this pass on.
+const char kLessBatchNormalizationPassName[] = "less_bn";
+// Input index that is followed when walking up a branch and that replaces a removed node.
+constexpr auto kValidResidualStructureIndex = 1;
+// First BatchNorm input index that is scanned for Load nodes.
+constexpr auto kBNParametersStartIndex = 2;
+
+// Every pattern below is a list of branch descriptors (kStructureTuple), one per branch in visiting order:
+//   field 0: number of nodes that have to be matched on the branch,
+//   field 1: primitive sequence the nodes are compared with (reused cyclically when the branch is longer),
+//   field 2: {first, last} range of the running matched-node count in which BatchNorm / TupleGetItem nodes are
+//            collected for removal; {SIZE_MAX, SIZE_MAX} therefore collects nothing on that branch.
+
+// Pattern 1
+// Add -> BatchNorm -> Conv2D -> Relu ... -> End
+//     ↘  BatchNorm -> Conv2D -> -> -> -> ↗
+constexpr auto kFirstBranchPattern1 = 12;
+constexpr auto kSecondBranchPattern1 = 3;
+constexpr auto kFirstBranchStartIndexPattern1 = 4;
+constexpr auto kFirstBranchEndIndexPattern1 = 11;
+const std::vector<kStructureTuple> ResidualStructureBasePattern{
+  {kFirstBranchPattern1,
+   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
+   {kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
+  {kSecondBranchPattern1, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
+// Pattern 2
+// Add -> BatchNorm -> Conv2D -> Relu ... -> End
+//     ↘  -> -> ... ... ... -> -> ↗
+constexpr auto kFirstBranchPattern2 = 12;
+constexpr auto kSecondBranchPattern2 = 1;
+constexpr auto kFirstBranchStartIndexPattern2 = 4;
+constexpr auto kFirstBranchEndIndexPattern2 = 11;
+const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
+  {kFirstBranchPattern2,
+   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
+   {kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
+  {kSecondBranchPattern2, {prim::kPrimRelu}, {SIZE_MAX, SIZE_MAX}}};
+// Pattern 3
+// Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
+//     ↘  BatchNorm -> Conv2D -> -> ... ... ... -> -> ↗
+constexpr auto kFirstBranchPattern3 = 11;
+constexpr auto kSecondBranchPattern3 = 3;
+constexpr auto kFirstBranchStartIndexPattern3 = 4;
+constexpr auto kFirstBranchEndIndexPattern3 = 10;
+const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
+  {kFirstBranchPattern3,
+   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
+    prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
+    prim::kPrimConv2D},
+   {kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
+  {kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
+// Patterns are tried in this order; matching stops at the first one that succeeds.
+static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
+  ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
+
+// Returns true when `a` has the same name as any Parameter node contained in `parameter_list`.
+// A null `a` never matches (the caller passes the result of a cast).
+bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
+  if (a == nullptr) {
+    return false;
+  }
+  return std::any_of(parameter_list.begin(), parameter_list.end(), [&a](const AnfNodePtr &b) {
+    return (b->isa<Parameter>() && a->name() == b->cast<ParameterPtr>()->name());
+  });
+}
+
+// Returns true only when the manager holds no user entry at all for `parameter`.
+bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
+  auto param_output = manager->node_users().find(parameter);
+  if (param_output == manager->node_users().end()) {
+    return true;
+  }
+
+  return false;
+}
+
+// Erases from the single root graph every parameter whose name matches an entry of `remove_parameter_list`
+// that has no user entry in the manager. Logs an error and leaves the graph untouched when the manager does
+// not have exactly one root.
+void RemoveBatchNormalizationNotUseParameters(const FuncGraphManagerPtr &manager,
+                                              const std::vector<AnfNodePtr> &remove_parameter_list) {
+  auto roots = manager->roots();
+  if (roots.size() != 1) {
+    MS_LOG(ERROR) << "The size of roots " << roots.size() << " is not valid.";
+    return;
+  }
+  auto root_graph = *(roots.begin());
+  MS_EXCEPTION_IF_NULL(root_graph);
+
+  std::vector<AnfNodePtr> real_remove_parameter_list;
+  std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
+               std::back_inserter(real_remove_parameter_list),
+               [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
+
+  auto root_parameters = root_graph->parameters();
+  root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
+                                       [&real_remove_parameter_list](const AnfNodePtr &node) {
+                                         return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
+                                       }),
+                        root_parameters.end());
+
+  manager->SetParameters(root_graph, root_parameters);
+}
+}  // namespace
+
+// Checks `cnode` against the primitive expected at position `index` of one branch descriptor.
+// `index` is the offset of the node inside its branch; the primitive sequence is indexed modulo its length.
+// Returns false for a negative index or an empty primitive sequence.
+bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
+                                                const kStructureTuple &patternTuple) {
+  const auto &use_pattern = std::get<1>(patternTuple);
+  // The empty check keeps the modulo below well defined.
+  if (index < 0 || use_pattern.empty()) {
+    return false;
+  }
+  const size_t use_index = static_cast<size_t>(index) % use_pattern.size();
+  return IsPrimitiveCNode(cnode, use_pattern[use_index]);
+}
+
+// Matches `cnode` against the current branch of `match_pattern` and advances the matcher state.
+// Returns true when the walk should continue with the node's input. Returns false when the node does not
+// match, when the current branch has just been completed (match_branch_ is advanced), or when the last branch
+// has been completed (is_match_ is set).
+bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
+                                                 const std::vector<kStructureTuple> &match_pattern) {
+  if ((match_branch_ + 1 >= total_match_node_.size()) || (match_branch_ >= match_pattern.size())) {
+    return false;
+  }
+
+  // total_match_node_ holds cumulative node counts, so the difference is the offset inside the current branch.
+  int32_t index = static_cast<int32_t>(match_node_) - static_cast<int32_t>(total_match_node_[match_branch_]);
+  const auto &pattern = match_pattern[match_branch_];
+  if (!MatchStructureNode(cnode, index, pattern)) {
+    return false;
+  }
+
+  match_node_++;
+  if (match_node_ == total_match_node_.back()) {
+    is_match_ = true;
+    return false;
+  }
+  if (match_node_ == total_match_node_[match_branch_ + 1]) {
+    match_branch_++;
+    return false;
+  }
+  return true;
+}
+
+// Records `cnode` for removal when it is a BatchNorm or TupleGetItem node and the number of nodes matched so
+// far lies inside the removal range of the current branch.
+void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
+  if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
+    return;
+  }
+  // Also covers an empty pattern; avoids an out-of-range access on the branch index.
+  if (match_branch_ >= match_pattern.size()) {
+    return;
+  }
+  const auto &start_end_pair = std::get<2>(match_pattern[match_branch_]);
+  if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
+    remove_node_list_.insert(cnode);
+  }
+}
+
+// Runs the pass on `node`. Does nothing unless the node's func graph carries the "less_bn" attribute.
+// Each known pattern is tried in turn through AnfVisitor::Match with prim::kPrimAdd and two CNode inputs.
+// On a match every collected BatchNorm / TupleGetItem node is replaced by its input at
+// kValidResidualStructureIndex, and root parameters that belonged to the removed BatchNorm nodes and have no
+// user entry left are dropped.
+// Returns `node` when nodes were replaced, nullptr otherwise.
+AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
+  const auto &fg = node->func_graph();
+  MS_EXCEPTION_IF_NULL(fg);
+  if (!fg->has_attr(kLessBatchNormalizationPassName)) {
+    return nullptr;
+  }
+  match_pattern_ = 0;
+  while (match_pattern_ < kNeedMatchPattern.size()) {
+    Reset();
+    const auto &current_pattern = kNeedMatchPattern.at(match_pattern_);
+    // Build the cumulative node counts that mark the end of every branch.
+    size_t sum_match_node = 0;
+    std::for_each(current_pattern.begin(), current_pattern.end(), [&](const kStructureTuple &t) {
+      sum_match_node += std::get<0>(t);
+      total_match_node_.emplace_back(sum_match_node);
+    });
+    AnfVisitor::Match(prim::kPrimAdd, {IsCNode, IsCNode})(node);
+    if (is_match_) {
+      break;
+    }
+    match_pattern_++;
+  }
+
+  if (!is_match_ || remove_node_list_.empty()) {
+    return nullptr;
+  }
+
+  auto manager = optimizer->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::vector<AnfNodePtr> remove_parameter_list;
+  for (const auto &remove_node : remove_node_list_) {
+    // Need to remove batchnorm's parameter input. Each BatchNorm contributes only its own Load inputs, so a
+    // parameter is not queued again for every BatchNorm handled afterwards.
+    if (IsPrimitiveCNode(remove_node, prim::kPrimBatchNorm)) {
+      const auto &bn_inputs = remove_node->inputs();
+      for (size_t i = kBNParametersStartIndex; i < bn_inputs.size(); ++i) {
+        if (IsPrimitiveCNode(bn_inputs[i], prim::kPrimLoad)) {
+          remove_parameter_list.emplace_back(bn_inputs[i]->cast<CNodePtr>()->input(kValidResidualStructureIndex));
+        }
+      }
+    }
+    // Remove useless node.
+    auto input_cnode = remove_node->input(kValidResidualStructureIndex);
+    manager->Replace(remove_node, input_cnode);
+  }
+  RemoveBatchNormalizationNotUseParameters(manager, remove_parameter_list);
+
+  return node;
+}
+
+// Visits one node of a branch: records it for removal if applicable, matches it against the current pattern
+// and, while the match goes on, continues with its input at kValidResidualStructureIndex.
+void LessBatchNormalization::Visit(const CNodePtr &cnode) {
+  if (cnode == nullptr) {
+    return;
+  }
+
+  const auto &current_pattern = kNeedMatchPattern.at(match_pattern_);
+  // The removal check has to run first: it relies on the matched-node count before this node is counted.
+  IsRemoveNode(cnode, current_pattern);
+  if (!MatchGraphStructure(cnode, current_pattern)) {
+    return;
+  }
+
+  auto search_input = cnode->input(kValidResidualStructureIndex);
+  if (search_input != nullptr && search_input->isa<CNode>()) {
+    this->Visit(search_input->cast<CNodePtr>());
+  }
+}
+
+// Clears the per-pattern matcher state; total_match_node_ restarts with the leading 0 boundary.
+// match_pattern_ is deliberately left untouched, it is driven by operator().
+void LessBatchNormalization::Reset() {
+  remove_node_list_.clear();
+  total_match_node_.clear();
+  total_match_node_.emplace_back(0);
+  match_node_ = 0;
+  match_branch_ = 0;
+  is_match_ = false;
+}
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.h b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.h
new file mode 100644
index 00000000000..c78e967b9bf
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
+
+// NOTE(review): the five header names and all template arguments in this file were lost when the patch was
+// extracted. They are restored from the standard types used below - confirm against the upstream commit.
+#include <cstdint>
+#include <tuple>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/anf_visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// Descriptor of one branch of a pattern:
+//   field 0: number of nodes to match on the branch,
+//   field 1: primitive sequence the nodes are compared with,
+//   field 2: {first, last} range of the matched-node count in which nodes are collected for removal.
+using kStructureTuple = std::tuple<size_t, std::vector<PrimitivePtr>, std::pair<size_t, size_t>>;
+
+// Visitor that matches a fixed set of branch patterns and collects BatchNorm / TupleGetItem nodes for removal.
+class LessBatchNormalization : public AnfVisitor {
+ public:
+  // Entry point of the pass; returns `node` when nodes were replaced, nullptr otherwise.
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
+  // Matches one node and keeps walking along its input while the pattern still matches.
+  void Visit(const CNodePtr &cnode) override;
+  // Clears the matcher state (everything below except match_pattern_).
+  void Reset();
+  // Adds `cnode` to remove_node_list_ when it qualifies for removal at the current matcher position.
+  void IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
+  // Checks `cnode` against the primitive at offset `index` of one branch descriptor.
+  bool MatchStructureNode(const CNodePtr &cnode, const int32_t index, const kStructureTuple &patternTuple);
+  // Matches `cnode` against the current branch and advances the state; true means "continue with the input".
+  bool MatchGraphStructure(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
+
+ private:
+  std::unordered_set<CNodePtr> remove_node_list_{};  // nodes collected for removal
+  std::vector<size_t> total_match_node_{0};          // cumulative node counts marking the end of each branch
+  size_t match_node_{0};                             // number of nodes matched so far
+  size_t match_branch_{0};                           // index of the branch being matched
+  size_t match_pattern_{0};                          // index of the pattern being tried
+  bool is_match_{false};                             // set once the last branch has been matched completely
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 7d7fd28dc3c..542404d8a48 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -167,6 +167,9 @@
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
       irpass.mini_step_allgather_replace_,
     },
     false, true);
+  opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({
+    irpass.less_batch_normalization_,
+  });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
 
   opt::irpass::ResolveIRPassLib resolve_irpass;
@@ -177,6 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
   OptPassGroupMap map_a({{"a_1", a_1},
                          {"a_2", a_2},
+                         {"accelerated_algorithm", accelerated_algorithm},
                          {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
                          {"parallel", opt::OptPassConfig(parallel::StepParallel)},
                          {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
diff --git a/mindspore/nn/acc/__init__.py b/mindspore/nn/acc/__init__.py
new file mode 100644
index 00000000000..8e07749c83b
--- /dev/null
+++ b/mindspore/nn/acc/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Acceleration algorithms.
+
+Provides automatic acceleration for networks, such as Less BN.
+"""
+from .less_batch_normalization import LessBN
+
+__all__ = ['LessBN']
diff --git a/mindspore/nn/acc/less_batch_normalization.py b/mindspore/nn/acc/less_batch_normalization.py
new file mode 100644
index 00000000000..95db064e540
--- /dev/null
+++ b/mindspore/nn/acc/less_batch_normalization.py
@@ -0,0 +1,36 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""less batch normalization"""
+from ..cell import Cell
+
+class LessBN(Cell):
+    """
+    Wrap a network and enable the "less_bn" acceleration algorithm on it, which reduces the number of BN
+    automatically to improve the network performance while aiming to keep the network accuracy.
+
+    Args:
+        network (Cell): Network to be modified. `set_acc("less_bn")` is called on it at construction.
+
+    Examples:
+        >>> network = acc.LessBN(network)
+    """
+
+    def __init__(self, network):
+        super(LessBN, self).__init__()
+        self.network = network
+        self.network.set_acc("less_bn")
+
+    def construct(self, *inputs):
+        return self.network(*inputs)
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index bd80dfd0738..d17658cb47b 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -1036,6 +1036,29 @@ class Cell(Cell_):
         self.add_flags_recursive(**flags)
         return self
 
+    def set_acc(self, acc_type):
+        """
+        Enable an acceleration algorithm from the algorithm library by setting its flag with `add_flags_recursive`.
+
+        Note:
+            Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
+
+        Args:
+            acc_type (str): Name of the acceleration algorithm to enable. Only "less_bn" is
+                currently in the algorithm library.
+
+        Returns:
+            Cell, the cell itself.
+
+        Raises:
+            ValueError: If `acc_type` is not in the algorithm library.
+        """
+        if acc_type not in ("less_bn",):
+            raise ValueError("acc_type is not in the algorithm library.")
+        flags = {"less_bn": acc_type == "less_bn"}
+        self.add_flags_recursive(**flags)
+        return self
+
     def set_grad(self, requires_grad=True):
         """
         Sets the cell flag for gradient.