fix_codex_pclint_r1.5

lichenever 2021-11-12 16:05:03 +08:00
parent 4faaca4df9
commit 8abc711298
5 changed files with 58 additions and 45 deletions

View File

@@ -487,6 +487,36 @@ void GatherInfo::InferInputsTensorMap() {
  inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
}

Shape GatherInfo::InferOutputsTensorMapSplitAxis() {
  Shape tensor_map_out;
  size_t param_size = inputs_shape_.at(0).size();
  size_t index_size = inputs_shape_.at(1).size();
  if (axis_ == 0) {
    if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
      // the output is repeat calculation
      tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
    } else {
      tensor_map_out.insert(tensor_map_out.end(), param_size - 1);
    }
    tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
    for (size_t i = 1; i < param_size; ++i) {
      tensor_map_out.push_back(param_size - 1 - i);
    }
  } else {
    for (size_t i = 0; i < param_size; ++i) {
      if (i == LongToSize(axis_)) {
        tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
      } else {
        if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
          tensor_map_out.push_back(MAP_NONE);
        }
        tensor_map_out.push_back(SizeToLong(i));
      }
    }
  }
  return tensor_map_out;
}

void GatherInfo::InferOutputsTensorMap() {
  // infer output tensor map
  size_t param_size = inputs_shape_.at(0).size();
@@ -507,29 +537,7 @@ void GatherInfo::InferOutputsTensorMap() {
    }
  } else {
    // param_strategy(axis) is not 1
    if (axis_ == 0) {
      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
        // the output is repeat calculation
        tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
      } else {
        tensor_map_out.insert(tensor_map_out.end(), param_size - 1);
      }
      tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
      for (size_t i = 1; i < param_size; ++i) {
        tensor_map_out.push_back(param_size - 1 - i);
      }
    } else {
      for (size_t i = 0; i < param_size; ++i) {
        if (i == LongToSize(axis_)) {
          tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
        } else {
          if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
            tensor_map_out.push_back(MAP_NONE);
          }
          tensor_map_out.push_back(SizeToLong(i));
        }
      }
    }
    tensor_map_out = InferOutputsTensorMapSplitAxis();
  }
  (void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
}

View File

@@ -72,6 +72,7 @@ class GatherInfo : public OperatorInfo {
  Status InferOffset();
  Status InferGroup();
  bool ShardBatchAndAxis(const Strategys &strategy) const;
  Shape InferOutputsTensorMapSplitAxis();

  int64_t axis_;
  std::string target_ = DEVICE;

View File

@@ -121,39 +121,40 @@ std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node,
  return new_node_input;
}

AnfNodePtr GetAccuGrad(const std::vector<AnfNodePtr> &parameters, const std::string &weight_name) {
  for (auto &param : parameters) {
    if (!ParameterIsCloned(param)) {
      continue;
    }
    auto param_ptr = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (param_ptr->name().find(weight_name) != std::string::npos &&
        param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
      MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
      return param;
    }
  }
  return nullptr;
}

std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
                                          const std::string &instance_name, const std::string &weight_name) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(root->manager());
  AnfNodePtr grad_accu = nullptr;
  std::string op_name = op.first;
  OperatorArgs arg_forward = op.second;
  AnfNodePtr grad_accu = nullptr;
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (grad_accumulation_step > 1 || split_stage_num > 1) {
    auto parameters = root->parameters();
    bool find_grad_accu_node = false;
    for (auto &param : parameters) {
      if (!ParameterIsCloned(param)) {
        continue;
      }
      auto param_ptr = param->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (param_ptr->name().find(weight_name) != std::string::npos &&
          param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
        find_grad_accu_node = true;
        grad_accu = param;
        MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
        break;
      }
    }
    if (!find_grad_accu_node) {
    grad_accu = GetAccuGrad(parameters, weight_name);
    if (!grad_accu) {
      if (op_name == MIRROR_MINI_STEP_OPERATOR) {
        op_name = MIRROR_OPERATOR;
        arg_forward.first.pop_back();
@@ -2756,7 +2757,8 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
      }
      std::vector<std::pair<int64_t, int64_t>> manual_shape;
      for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
        manual_shape.emplace_back(std::make_pair(param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]));
        (void)manual_shape.emplace_back(
          std::make_pair(param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]));
      }
      manual_shape_map[param_name] = manual_shape;
    }

View File

@@ -74,6 +74,8 @@ bool StrategyFound(const std::unordered_map<std::string, ValuePtr> &attrs);
bool AttrFound(const std::unordered_map<std::string, ValuePtr> &attrs, const std::string &target);
AnfNodePtr GetAccuGrad(const std::vector<AnfNodePtr> &parameters, const std::string &weight_name);
void MarkForwardCNode(const FuncGraphPtr &root);
bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);

View File

@@ -251,9 +251,9 @@ class Cell(Cell_):
    @pipeline_stage.setter
    def pipeline_stage(self, value):
        if isinstance(value, bool):
            raise TypeError("'pipeline_stage' must be a int type, but got bool.")
            raise TypeError("'pipeline_stage' must be an int type, but got bool.")
        if not isinstance(value, int):
            raise TypeError("'pipeline_stage' must be a int type.")
            raise TypeError("'pipeline_stage' must be an int type.")
        if value < 0:
            raise TypeError("'pipeline_stage' can not be less than 0.")
        self._pipeline_stage = value