forked from mindspore-Ecosystem/mindspore
[Less BN]Eliminating BN according to pattern matching.
This commit is contained in:
parent
782cac9119
commit
a89a0d4810
|
@ -31,6 +31,7 @@
|
||||||
#include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
|
#include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
|
||||||
#include "frontend/optimizer/irpass/merge_addn.h"
|
#include "frontend/optimizer/irpass/merge_addn.h"
|
||||||
#include "frontend/optimizer/irpass/accumulaten_eliminate.h"
|
#include "frontend/optimizer/irpass/accumulaten_eliminate.h"
|
||||||
|
#include "frontend/optimizer/irpass/less_batch_normalization.h"
|
||||||
#include "frontend/optimizer/irpass/minmax_grad.h"
|
#include "frontend/optimizer/irpass/minmax_grad.h"
|
||||||
#include "frontend/optimizer/irpass/param_replace.h"
|
#include "frontend/optimizer/irpass/param_replace.h"
|
||||||
#include "frontend/optimizer/irpass/partial_eliminate.h"
|
#include "frontend/optimizer/irpass/partial_eliminate.h"
|
||||||
|
@ -143,6 +144,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
accumulaten_eliminater_ =
|
accumulaten_eliminater_ =
|
||||||
MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);
|
MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);
|
||||||
|
|
||||||
|
// Accelerated Algorithm
|
||||||
|
less_batch_normalization_ =
|
||||||
|
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization", prim::kPrimAdd);
|
||||||
|
|
||||||
// inline
|
// inline
|
||||||
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
||||||
inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
|
inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
|
||||||
|
|
|
@ -84,6 +84,9 @@ class OptimizeIRPassLib {
|
||||||
// AccumulateNV2
|
// AccumulateNV2
|
||||||
SubstitutionPtr accumulaten_eliminater_;
|
SubstitutionPtr accumulaten_eliminater_;
|
||||||
|
|
||||||
|
// Accelerated Algorithm
|
||||||
|
SubstitutionPtr less_batch_normalization_;
|
||||||
|
|
||||||
// Gradient irpasses
|
// Gradient irpasses
|
||||||
SubstitutionPtr minmaximum_grad_;
|
SubstitutionPtr minmaximum_grad_;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,238 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "frontend/optimizer/irpass/less_batch_normalization.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace irpass {
|
||||||
|
namespace {
|
||||||
|
const char kLessBatchNormalizationPassName[] = "less_bn";
|
||||||
|
constexpr auto kValidResidualStructureIndex = 1;
|
||||||
|
constexpr auto kBNParametersStartIndex = 2;
|
||||||
|
// Pattern 1
|
||||||
|
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
||||||
|
// ↘ BatchNorm -> Conv2D -> -> -> -> ↗
|
||||||
|
constexpr auto kFirstBranchPattern1 = 12;
|
||||||
|
constexpr auto kSecondBranchPattern1 = 3;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern1 = 4;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern1 = 11;
|
||||||
|
const std::vector<kStructureTuple> ResidualStructureBasePattern{
|
||||||
|
{kFirstBranchPattern1,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||||
|
{kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
|
||||||
|
{kSecondBranchPattern1, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
|
||||||
|
// Pattern 2
|
||||||
|
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
||||||
|
// ↘ -> -> ... ... ... -> -> ↗
|
||||||
|
constexpr auto kFirstBranchPattern2 = 12;
|
||||||
|
constexpr auto kSecondBranchPattern2 = 1;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern2 = 4;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern2 = 11;
|
||||||
|
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
|
||||||
|
{kFirstBranchPattern2,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||||
|
{kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
|
||||||
|
{kSecondBranchPattern2, {prim::kPrimRelu}, {SIZE_MAX, SIZE_MAX}}};
|
||||||
|
// Pattern 3
|
||||||
|
// Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
|
||||||
|
// ↘ BatchNorm -> Conv2D -> -> ... ... ... -> -> ↗
|
||||||
|
constexpr auto kFirstBranchPattern3 = 11;
|
||||||
|
constexpr auto kSecondBranchPattern3 = 3;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern3 = 4;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern3 = 10;
|
||||||
|
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
|
||||||
|
{kFirstBranchPattern3,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
||||||
|
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
|
||||||
|
prim::kPrimConv2D},
|
||||||
|
{kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
|
||||||
|
{kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
|
||||||
|
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
|
||||||
|
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
|
||||||
|
|
||||||
|
bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> ¶meter_list) {
|
||||||
|
if (a == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return std::any_of(parameter_list.begin(), parameter_list.end(), [&a](const AnfNodePtr &b) {
|
||||||
|
return (b->isa<Parameter>() && a->name() == b->cast<ParameterPtr>()->name());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr ¶meter) {
|
||||||
|
auto param_output = manager->node_users().find(parameter);
|
||||||
|
if (param_output == manager->node_users().end()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
|
||||||
|
const std::vector<AnfNodePtr> &remove_parameter_list) {
|
||||||
|
auto roots = manager->roots();
|
||||||
|
if (roots.size() != 1) {
|
||||||
|
MS_LOG(ERROR) << "The size of roots " << roots.size() << " is not valid.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto root_graph = *(roots.begin());
|
||||||
|
MS_EXCEPTION_IF_NULL(root_graph);
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> real_remove_parameter_list;
|
||||||
|
std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
|
||||||
|
std::back_inserter(real_remove_parameter_list),
|
||||||
|
[&manager](const AnfNodePtr ¶m) { return IsRealRemoveParameterNode(manager, param); });
|
||||||
|
|
||||||
|
auto root_parameters = root_graph->parameters();
|
||||||
|
root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
|
||||||
|
[&real_remove_parameter_list](const AnfNodePtr &node) {
|
||||||
|
return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
|
||||||
|
}),
|
||||||
|
root_parameters.end());
|
||||||
|
|
||||||
|
manager->SetParameters(root_graph, root_parameters);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
|
||||||
|
const kStructureTuple &patternTuple) {
|
||||||
|
if (index < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto &use_pattern = std::get<1>(patternTuple);
|
||||||
|
int32_t use_index = index % use_pattern.size();
|
||||||
|
if (!IsPrimitiveCNode(cnode, use_pattern[use_index])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
|
||||||
|
const std::vector<kStructureTuple> &match_pattern) {
|
||||||
|
if ((match_branch_ + 1 >= total_match_node_.size()) || (match_branch_ >= match_pattern.size())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t index = static_cast<int32_t>(match_node_) - static_cast<int32_t>(total_match_node_[match_branch_]);
|
||||||
|
const auto &pattern = match_pattern[match_branch_];
|
||||||
|
if (!MatchStructureNode(cnode, index, pattern)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
match_node_++;
|
||||||
|
if (match_node_ == total_match_node_.back()) {
|
||||||
|
is_match_ = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (match_node_ == total_match_node_[match_branch_ + 1]) {
|
||||||
|
match_branch_++;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
|
||||||
|
if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (match_pattern.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
|
||||||
|
if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
|
||||||
|
remove_node_list_.insert(cnode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||||
|
const auto &fg = node->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
if (!fg->has_attr(kLessBatchNormalizationPassName)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
match_pattern_ = 0;
|
||||||
|
while (match_pattern_ < kNeedMatchPattern.size()) {
|
||||||
|
Reset();
|
||||||
|
const auto ¤t_pattern = kNeedMatchPattern.at(match_pattern_);
|
||||||
|
size_t sum_match_node = 0;
|
||||||
|
std::for_each(current_pattern.begin(), current_pattern.end(), [&](const kStructureTuple &t) {
|
||||||
|
sum_match_node += std::get<0>(t);
|
||||||
|
total_match_node_.emplace_back(sum_match_node);
|
||||||
|
});
|
||||||
|
AnfVisitor::Match(prim::kPrimAdd, {IsCNode, IsCNode})(node);
|
||||||
|
if (is_match_) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
match_pattern_++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_match_ || remove_node_list_.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto manager = optimizer->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
std::vector<AnfNodePtr> remove_load_list;
|
||||||
|
std::vector<AnfNodePtr> remove_parameter_list;
|
||||||
|
for (auto &iter : remove_node_list_) {
|
||||||
|
// Need to remove batchnorm's parameter input.
|
||||||
|
if (IsPrimitiveCNode(iter, prim::kPrimBatchNorm)) {
|
||||||
|
std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(),
|
||||||
|
std::back_inserter(remove_load_list),
|
||||||
|
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
|
||||||
|
std::transform(
|
||||||
|
remove_load_list.begin(), remove_load_list.end(), std::back_inserter(remove_parameter_list),
|
||||||
|
[](const AnfNodePtr &node) { return node->cast<CNodePtr>()->input(kValidResidualStructureIndex); });
|
||||||
|
}
|
||||||
|
// Remove useless node.
|
||||||
|
auto input_cnode = iter->input(kValidResidualStructureIndex);
|
||||||
|
manager->Replace(iter, input_cnode);
|
||||||
|
}
|
||||||
|
RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);
|
||||||
|
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LessBatchNormalization::Visit(const CNodePtr &cnode) {
|
||||||
|
if (cnode == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto ¤t_pattern = kNeedMatchPattern.at(match_pattern_);
|
||||||
|
IsRemoveNode(cnode, current_pattern);
|
||||||
|
if (!MatchGraphStructure(cnode, current_pattern)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto search_input = cnode->input(kValidResidualStructureIndex);
|
||||||
|
if (search_input != nullptr && search_input->isa<CNode>()) {
|
||||||
|
this->Visit(search_input->cast<CNodePtr>());
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LessBatchNormalization::Reset() {
|
||||||
|
remove_node_list_.clear();
|
||||||
|
total_match_node_.clear();
|
||||||
|
total_match_node_.emplace_back(0);
|
||||||
|
match_node_ = 0;
|
||||||
|
match_branch_ = 0;
|
||||||
|
is_match_ = false;
|
||||||
|
}
|
||||||
|
} // namespace irpass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
|
||||||
|
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <tuple>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "frontend/optimizer/irpass.h"
|
||||||
|
#include "frontend/optimizer/optimizer.h"
|
||||||
|
#include "frontend/optimizer/anf_visitor.h"
|
||||||
|
#include "frontend/operator/ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace irpass {
|
||||||
|
using kStructureTuple = std::tuple<size_t, std::vector<PrimitivePtr>, std::pair<size_t, size_t>>;
|
||||||
|
class LessBatchNormalization : public AnfVisitor {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||||
|
void Visit(const CNodePtr &cnode) override;
|
||||||
|
void Reset();
|
||||||
|
void IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
|
||||||
|
bool MatchStructureNode(const CNodePtr &cnode, const int32_t index, const kStructureTuple &patternTuple);
|
||||||
|
bool MatchGraphStructure(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_set<CNodePtr> remove_node_list_{};
|
||||||
|
std::vector<size_t> total_match_node_{0};
|
||||||
|
size_t match_node_{0};
|
||||||
|
size_t match_branch_{0};
|
||||||
|
size_t match_pattern_{0};
|
||||||
|
bool is_match_{false};
|
||||||
|
};
|
||||||
|
} // namespace irpass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LESS_BATCH_NORMALIZATION_H_
|
|
@ -167,6 +167,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
irpass.mini_step_allgather_replace_,
|
irpass.mini_step_allgather_replace_,
|
||||||
},
|
},
|
||||||
false, true);
|
false, true);
|
||||||
|
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({
|
||||||
|
irpass.less_batch_normalization_,
|
||||||
|
});
|
||||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||||
opt::irpass::ResolveIRPassLib resolve_irpass;
|
opt::irpass::ResolveIRPassLib resolve_irpass;
|
||||||
|
|
||||||
|
@ -177,6 +180,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
||||||
OptPassGroupMap map_a({{"a_1", a_1},
|
OptPassGroupMap map_a({{"a_1", a_1},
|
||||||
{"a_2", a_2},
|
{"a_2", a_2},
|
||||||
|
{"accelerated_algorithm", accelerated_algorithm},
|
||||||
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
|
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
|
||||||
{"parallel", opt::OptPassConfig(parallel::StepParallel)},
|
{"parallel", opt::OptPassConfig(parallel::StepParallel)},
|
||||||
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
|
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Accelerating.
|
||||||
|
|
||||||
|
Provide auto accelerating for network, such as Less BN.
|
||||||
|
"""
|
||||||
|
from .less_batch_normalization import LessBN
|
||||||
|
|
||||||
|
__all__ = ['LessBN']
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""less batch normalization"""
|
||||||
|
from ..cell import Cell
|
||||||
|
|
||||||
|
class LessBN(Cell):
|
||||||
|
"""
|
||||||
|
Reduce the number of BN automatically to improve the network performance
|
||||||
|
and ensure the network accuracy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): Network to be modified.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> network = acc.LessBN(network)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network):
|
||||||
|
super(LessBN, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
self.network.set_acc("less_bn")
|
||||||
|
|
||||||
|
def construct(self, *inputs):
|
||||||
|
return self.network(*inputs)
|
|
@ -1036,6 +1036,29 @@ class Cell(Cell_):
|
||||||
self.add_flags_recursive(**flags)
|
self.add_flags_recursive(**flags)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def set_acc(self, acc_type):
|
||||||
|
"""
|
||||||
|
In order to improve the network performance, configure the network auto enable to
|
||||||
|
accelerate the algorithm in the algorithm library.
|
||||||
|
|
||||||
|
If `acc_type is not in the algorithm library`, Please view the algorithm in the algorithm library
|
||||||
|
through `algorithm library`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
acc_type (:str:`less_bn`): accelerate algorithm.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If acc_type is not in the algorithm library.
|
||||||
|
"""
|
||||||
|
if acc_type not in ("less_bn",):
|
||||||
|
raise ValueError("acc_type is not in the algorithm library.")
|
||||||
|
flags = {"less_bn": acc_type == "less_bn"}
|
||||||
|
self.add_flags_recursive(**flags)
|
||||||
|
return self
|
||||||
|
|
||||||
def set_grad(self, requires_grad=True):
|
def set_grad(self, requires_grad=True):
|
||||||
"""
|
"""
|
||||||
Sets the cell flag for gradient.
|
Sets the cell flag for gradient.
|
||||||
|
|
Loading…
Reference in New Issue