!20709 Fix lessbn code pclint.
Merge pull request !20709 from linqingke/highest-performance
This commit is contained in:
commit
f03dd82c11
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "frontend/optimizer/irpass/less_batch_normalization.h"
|
||||
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
@ -302,7 +305,7 @@ bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNode
|
|||
if (IsNotRealUseNode(node)) {
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
|
||||
manager->Replace(cnode, new_cnode);
|
||||
(void)manager->Replace(cnode, new_cnode);
|
||||
continue;
|
||||
}
|
||||
need_remove = false;
|
||||
|
@ -322,17 +325,18 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
|
|||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
|
||||
std::vector<AnfNodePtr> real_remove_parameter_list;
|
||||
std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
|
||||
std::back_inserter(real_remove_parameter_list),
|
||||
[&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
|
||||
(void)std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
|
||||
std::back_inserter(real_remove_parameter_list),
|
||||
[&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
|
||||
|
||||
auto root_parameters = root_graph->parameters();
|
||||
size_t origin_param_count = root_parameters.size();
|
||||
root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
|
||||
[&real_remove_parameter_list](const AnfNodePtr &node) {
|
||||
return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
|
||||
}),
|
||||
root_parameters.end());
|
||||
(void)root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
|
||||
[&real_remove_parameter_list](const AnfNodePtr &node) {
|
||||
return NeedRemove(node->cast<ParameterPtr>(),
|
||||
real_remove_parameter_list);
|
||||
}),
|
||||
root_parameters.end());
|
||||
size_t remove_param_count = origin_param_count - root_parameters.size();
|
||||
size_t hyper_param_count = root_graph->hyper_param_count();
|
||||
if (remove_param_count > hyper_param_count) {
|
||||
|
@ -346,12 +350,12 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
|
|||
} // namespace
|
||||
|
||||
bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
|
||||
const kStructureTuple &patternTuple) {
|
||||
const kStructureTuple &patternTuple) const {
|
||||
if (index < 0) {
|
||||
return false;
|
||||
}
|
||||
const auto &use_pattern = std::get<1>(patternTuple);
|
||||
int32_t use_index = index % use_pattern.size();
|
||||
int32_t use_index = index % static_cast<int32_t>(use_pattern.size());
|
||||
if (!IsPrimitiveCNode(cnode, use_pattern[use_index])) {
|
||||
return false;
|
||||
}
|
||||
|
@ -391,7 +395,7 @@ void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vect
|
|||
}
|
||||
const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
|
||||
if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
|
||||
remove_node_list_.insert(cnode);
|
||||
(void)remove_node_list_.insert(cnode);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -408,7 +412,7 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con
|
|||
size_t sum_match_node = 0;
|
||||
std::for_each(current_pattern.begin(), current_pattern.end(), [&](const kStructureTuple &t) {
|
||||
sum_match_node += std::get<0>(t);
|
||||
total_match_node_.emplace_back(sum_match_node);
|
||||
(void)total_match_node_.emplace_back(sum_match_node);
|
||||
});
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr || cnode->inputs().empty()) {
|
||||
|
@ -434,16 +438,16 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con
|
|||
for (auto &iter : remove_node_list_) {
|
||||
// Need to remove batchnorm's parameter input.
|
||||
if (IsPrimitiveCNode(iter, prim::kPrimBatchNorm)) {
|
||||
std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(),
|
||||
std::back_inserter(remove_load_list),
|
||||
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
|
||||
std::transform(
|
||||
(void)std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(),
|
||||
std::back_inserter(remove_load_list),
|
||||
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
|
||||
(void)std::transform(
|
||||
remove_load_list.begin(), remove_load_list.end(), std::back_inserter(remove_parameter_list),
|
||||
[](const AnfNodePtr &node) { return node->cast<CNodePtr>()->input(kValidResidualStructureIndex); });
|
||||
}
|
||||
// Remove useless node.
|
||||
auto input_cnode = iter->input(kValidResidualStructureIndex);
|
||||
manager->Replace(iter, input_cnode);
|
||||
(void)manager->Replace(iter, input_cnode);
|
||||
}
|
||||
RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);
|
||||
|
||||
|
@ -471,7 +475,7 @@ void LessBatchNormalization::Visit(const CNodePtr &cnode) {
|
|||
void LessBatchNormalization::Reset() {
|
||||
remove_node_list_.clear();
|
||||
total_match_node_.clear();
|
||||
total_match_node_.emplace_back(0);
|
||||
(void)total_match_node_.emplace_back(0);
|
||||
match_node_ = 0;
|
||||
match_branch_ = 0;
|
||||
is_match_ = false;
|
||||
|
|
|
@ -38,7 +38,7 @@ class LessBatchNormalization : public AnfVisitor {
|
|||
void Visit(const CNodePtr &cnode) override;
|
||||
void Reset();
|
||||
void IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
|
||||
bool MatchStructureNode(const CNodePtr &cnode, const int32_t index, const kStructureTuple &patternTuple);
|
||||
bool MatchStructureNode(const CNodePtr &cnode, const int32_t index, const kStructureTuple &patternTuple) const;
|
||||
bool MatchGraphStructure(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern);
|
||||
|
||||
private:
|
||||
|
|
|
@ -17,7 +17,8 @@ import copy
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.optim import LARS
|
||||
from mindspore import log as logger
|
||||
from mindspore.common import Parameter
|
||||
from mindspore.common import Parameter, Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -136,6 +137,8 @@ class ParameterProcess:
|
|||
group_params = []
|
||||
params_name = [param.name for param in parameters]
|
||||
new_params_count = copy.deepcopy(params_name)
|
||||
new_params_clone = {}
|
||||
max_key_number = 0
|
||||
for group_param in origin_params_copy:
|
||||
if 'order_params' in group_param.keys():
|
||||
new_group_param = copy.deepcopy(group_param)
|
||||
|
@ -151,12 +154,19 @@ class ParameterProcess:
|
|||
new_group_param = copy.deepcopy(group_param)
|
||||
new_group_param['params'] = params_value
|
||||
group_params.append(new_group_param)
|
||||
if len(group_param.keys()) > max_key_number:
|
||||
max_key_number = len(group_param.keys())
|
||||
new_params_clone = copy.deepcopy(group_param)
|
||||
if new_params_count:
|
||||
params_value = []
|
||||
for param in new_params_count:
|
||||
index = params_name.index(param)
|
||||
params_value.append(parameters[index])
|
||||
group_params.append({"params": params_value})
|
||||
if new_params_clone:
|
||||
new_params_clone['params'] = params_value
|
||||
group_params.append(new_params_clone)
|
||||
else:
|
||||
group_params.append({"params": params_value})
|
||||
return group_params
|
||||
|
||||
_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
|
||||
|
|
|
@ -28,8 +28,8 @@ pretrain_epoch_size: 0
|
|||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 5
|
||||
keep_checkpoint_max: 10
|
||||
warmup_epochs: 0
|
||||
lr_decay_mode: "linear"
|
||||
warmup_epochs: 5
|
||||
lr_decay_mode: "cosine"
|
||||
use_label_smooth: True
|
||||
label_smooth_factor: 0.1
|
||||
lr_init: 0
|
||||
|
|
Loading…
Reference in New Issue