!20709 Fix pclint warnings in lessbn code.

Merge pull request !20709 from linqingke/highest-performance
i-robot 2021-07-23 01:16:54 +00:00 committed by Gitee
commit f03dd82c11
4 changed files with 38 additions and 24 deletions
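Most of the C++ changes below apply one PC-lint cleanup pattern: when a function's return value is deliberately ignored, the call is cast to void so the "discarded return value" diagnostic is suppressed and the intent is documented. A minimal standalone sketch of the pattern (illustrative only, not code from this repository):

#include <set>

int main() {
  std::set<int> visited;
  // std::set::insert returns std::pair<iterator, bool>; when the result is
  // not needed, a bare call trips the lint check. Casting to void shows the
  // value is discarded on purpose.
  (void)visited.insert(42);
  return 0;
}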

View File

@@ -16,6 +16,9 @@
 #include "frontend/optimizer/irpass/less_batch_normalization.h"
+#include <set>
+#include <unordered_map>
 
 namespace mindspore {
 namespace opt {
 namespace irpass {
@@ -302,7 +305,7 @@ bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNode
     if (IsNotRealUseNode(node)) {
       const auto &cnode = node->cast<CNodePtr>();
       const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
-      manager->Replace(cnode, new_cnode);
+      (void)manager->Replace(cnode, new_cnode);
       continue;
     }
     need_remove = false;
@@ -322,17 +325,18 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
   MS_EXCEPTION_IF_NULL(root_graph);
   std::vector<AnfNodePtr> real_remove_parameter_list;
-  std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
-               std::back_inserter(real_remove_parameter_list),
-               [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
+  (void)std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
+                     std::back_inserter(real_remove_parameter_list),
+                     [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
   auto root_parameters = root_graph->parameters();
   size_t origin_param_count = root_parameters.size();
-  root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
-                                       [&real_remove_parameter_list](const AnfNodePtr &node) {
-                                         return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
-                                       }),
-                        root_parameters.end());
+  (void)root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
+                                             [&real_remove_parameter_list](const AnfNodePtr &node) {
+                                               return NeedRemove(node->cast<ParameterPtr>(),
+                                                                 real_remove_parameter_list);
+                                             }),
+                              root_parameters.end());
   size_t remove_param_count = origin_param_count - root_parameters.size();
   size_t hyper_param_count = root_graph->hyper_param_count();
   if (remove_param_count > hyper_param_count) {
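The erase call above is the erase-remove idiom: std::remove_if compacts the elements to keep, and erase trims the tail; the iterator that std::vector::erase returns is discarded with (void). A reduced sketch, assuming a simple predicate:

#include <algorithm>
#include <vector>

int main() {
  std::vector<int> params = {1, 2, 3, 4, 5};
  // Drop even values; erase returns an iterator that is intentionally unused.
  (void)params.erase(std::remove_if(params.begin(), params.end(),
                                    [](int v) { return v % 2 == 0; }),
                     params.end());
  return params.size() == 3 ? 0 : 1;  // {1, 3, 5} remain
}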
@@ -346,12 +350,12 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
 } // namespace
 
 bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
-                                                const kStructureTuple &patternTuple) {
+                                                const kStructureTuple &patternTuple) const {
   if (index < 0) {
     return false;
   }
   const auto &use_pattern = std::get<1>(patternTuple);
-  int32_t use_index = index % use_pattern.size();
+  int32_t use_index = index % static_cast<int32_t>(use_pattern.size());
   if (!IsPrimitiveCNode(cnode, use_pattern[use_index])) {
     return false;
   }
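use_pattern.size() yields an unsigned size_t, so the old "index % use_pattern.size()" silently converts the signed index before the modulo, which PC-lint reports; the static_cast makes the conversion explicit. A reduced illustration (names are hypothetical):

#include <cstdint>
#include <vector>

int main() {
  const std::vector<int> pattern = {10, 20, 30};
  const int32_t index = 7;
  // Without the cast, 'index' would be promoted to an unsigned type for the
  // modulo. The explicit cast is safe here because the pattern is small.
  const int32_t use_index = index % static_cast<int32_t>(pattern.size());
  return pattern[use_index] == 20 ? 0 : 1;  // 7 % 3 == 1
}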
@@ -391,7 +395,7 @@ void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vect
   }
   const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
   if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
-    remove_node_list_.insert(cnode);
+    (void)remove_node_list_.insert(cnode);
   }
 }
@@ -408,7 +412,7 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con
   size_t sum_match_node = 0;
   std::for_each(current_pattern.begin(), current_pattern.end(), [&](const kStructureTuple &t) {
     sum_match_node += std::get<0>(t);
-    total_match_node_.emplace_back(sum_match_node);
+    (void)total_match_node_.emplace_back(sum_match_node);
   });
   auto cnode = node->cast<CNodePtr>();
   if (cnode == nullptr || cnode->inputs().empty()) {
@@ -434,16 +438,16 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con
   for (auto &iter : remove_node_list_) {
     // Need to remove batchnorm's parameter input.
     if (IsPrimitiveCNode(iter, prim::kPrimBatchNorm)) {
-      std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(),
-                   std::back_inserter(remove_load_list),
-                   [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
-      std::transform(
+      (void)std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(),
+                         std::back_inserter(remove_load_list),
+                         [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
+      (void)std::transform(
         remove_load_list.begin(), remove_load_list.end(), std::back_inserter(remove_parameter_list),
         [](const AnfNodePtr &node) { return node->cast<CNodePtr>()->input(kValidResidualStructureIndex); });
     }
     // Remove useless node.
     auto input_cnode = iter->input(kValidResidualStructureIndex);
-    manager->Replace(iter, input_cnode);
+    (void)manager->Replace(iter, input_cnode);
   }
 
   RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);
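std::copy_if and std::transform both return the output iterator one past the last element written; with a std::back_inserter that iterator carries no useful information, so it is discarded with (void). A minimal sketch:

#include <algorithm>
#include <iterator>
#include <vector>

int main() {
  const std::vector<int> src = {1, 2, 3, 4};
  std::vector<int> even;
  std::vector<int> doubled;
  // Both calls write through back_inserter; the returned iterators are unused.
  (void)std::copy_if(src.begin(), src.end(), std::back_inserter(even),
                     [](int v) { return v % 2 == 0; });
  (void)std::transform(src.begin(), src.end(), std::back_inserter(doubled),
                       [](int v) { return v * 2; });
  return (even.size() == 2 && doubled.size() == 4) ? 0 : 1;
}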
@@ -471,7 +475,7 @@ void LessBatchNormalization::Visit(const CNodePtr &cnode) {
 void LessBatchNormalization::Reset() {
   remove_node_list_.clear();
   total_match_node_.clear();
-  total_match_node_.emplace_back(0);
+  (void)total_match_node_.emplace_back(0);
   match_node_ = 0;
   match_branch_ = 0;
   is_match_ = false;

View File

@@ -38,7 +38,7 @@ class LessBatchNormalization : public AnfVisitor {
   void Visit(const CNodePtr &cnode) override;
   void Reset();
   void IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
-  bool MatchStructureNode(const CNodePtr &cnode, const int32_t index, const kStructureTuple &patternTuple);
+  bool MatchStructureNode(const CNodePtr &cnode, const int32_t index, const kStructureTuple &patternTuple) const;
   bool MatchGraphStructure(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
 
  private:
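The header change mirrors the definition above: MatchStructureNode only inspects its arguments and the pattern tuple, so it is declared const, letting the compiler and PC-lint verify it cannot mutate the visitor's state. A reduced illustration with a hypothetical class:

class Matcher {
 public:
  // const member function: callable on a const object and statically
  // prevented from modifying members.
  bool Matches(int index) const { return index == expected_; }

 private:
  int expected_ = 0;
};

int main() {
  const Matcher m{};
  return m.Matches(0) ? 0 : 1;
}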

View File

@@ -17,7 +17,8 @@ import copy
 from mindspore.nn.cell import Cell
 from mindspore.nn.optim import LARS
 from mindspore import log as logger
-from mindspore.common import Parameter
+from mindspore.common import Parameter, Tensor
+from mindspore.common import dtype as mstype
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -136,6 +137,8 @@ class ParameterProcess:
         group_params = []
         params_name = [param.name for param in parameters]
         new_params_count = copy.deepcopy(params_name)
+        new_params_clone = {}
+        max_key_number = 0
         for group_param in origin_params_copy:
             if 'order_params' in group_param.keys():
                 new_group_param = copy.deepcopy(group_param)
@@ -151,12 +154,19 @@ class ParameterProcess:
             new_group_param = copy.deepcopy(group_param)
             new_group_param['params'] = params_value
             group_params.append(new_group_param)
+            if len(group_param.keys()) > max_key_number:
+                max_key_number = len(group_param.keys())
+                new_params_clone = copy.deepcopy(group_param)
         if new_params_count:
             params_value = []
             for param in new_params_count:
                 index = params_name.index(param)
                 params_value.append(parameters[index])
-            group_params.append({"params": params_value})
+            if new_params_clone:
+                new_params_clone['params'] = params_value
+                group_params.append(new_params_clone)
+            else:
+                group_params.append({"params": params_value})
         return group_params
 
 _gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
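The Python change records the parameter group with the most configured keys (max_key_number) and deep-copies it into new_params_clone, so parameters not covered by any explicit group inherit that group's settings instead of landing in a bare {"params": ...} entry with all-default hyperparameters. The same idea in a C++ sketch with string-keyed maps (all names illustrative):

#include <cstddef>
#include <map>
#include <string>
#include <vector>

using Group = std::map<std::string, std::string>;

int main() {
  std::vector<Group> groups = {
      {{"params", "conv"}, {"lr", "0.1"}},
      {{"params", "bn"}, {"lr", "0.01"}, {"weight_decay", "0.0"}}};
  // Remember the richest group spec while iterating.
  Group clone;
  std::size_t max_key_number = 0;
  for (const auto &g : groups) {
    if (g.size() > max_key_number) {
      max_key_number = g.size();
      clone = g;
    }
  }
  // Leftover parameters reuse the cloned settings, overriding only 'params'.
  clone["params"] = "leftover";
  groups.push_back(clone);
  return groups.size() == 3 ? 0 : 1;
}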

View File

@@ -28,8 +28,8 @@ pretrain_epoch_size: 0
 save_checkpoint: True
 save_checkpoint_epochs: 5
 keep_checkpoint_max: 10
-warmup_epochs: 0
-lr_decay_mode: "linear"
+warmup_epochs: 5
+lr_decay_mode: "cosine"
 use_label_smooth: True
 label_smooth_factor: 0.1
 lr_init: 0
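The config switch enables a 5-epoch warmup followed by cosine decay in place of an immediate linear decay. In the common cosine-with-warmup form (the training script's exact variant may differ), with warmup length $t_w = 5$ epochs, total epochs $T$, and the config's lr_init, lr_max, and lr_end as $\eta_{\mathrm{init}}$, $\eta_{\max}$, $\eta_{\mathrm{end}}$:

\[
\eta(t) =
\begin{cases}
\eta_{\mathrm{init}} + (\eta_{\max} - \eta_{\mathrm{init}})\,\dfrac{t}{t_w}, & t < t_w,\\[6pt]
\eta_{\mathrm{end}} + \dfrac{\eta_{\max} - \eta_{\mathrm{end}}}{2}\left(1 + \cos\!\left(\pi\,\dfrac{t - t_w}{T - t_w}\right)\right), & t \ge t_w.
\end{cases}
\]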