!28746 Parallel module code alarm clearing
Merge pull request !28746 from liuluobin/master
This commit is contained in commit 4d9437c0aa.
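The hunks below all serve a single goal: clearing static-analysis alarms in the parallel module. The recurring fixes are (a) converting signed indices through LongToSize/SizeToLong before mixing them with size_t, (b) casting deliberately discarded return values to (void), (c) passing heavyweight arguments by const reference, (d) const-qualifying pure query methods, and (e) dropping unused parameters and locals. As a hedged sketch of what such conversion helpers typically do — the Demo names below are mine, not MindSpore's actual implementations:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-ins for SizeToLong/LongToSize: they make the
// signed<->unsigned conversion explicit and checkable instead of implicit.
inline int64_t SizeToLongDemo(size_t v) {
  if (v > static_cast<size_t>(INT64_MAX)) throw std::out_of_range("size_t value overflows int64_t");
  return static_cast<int64_t>(v);
}

inline size_t LongToSizeDemo(int64_t v) {
  if (v < 0) throw std::out_of_range("negative value cannot become size_t");
  return static_cast<size_t>(v);
}
```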
@@ -164,8 +164,8 @@ Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": axis value is out of range.";
   }
-  if (strategies[0][axis] != 1) {
-    strategies[0][axis] = 1;
+  if (strategies[0][LongToSize(axis)] != 1) {
+    strategies[0][LongToSize(axis)] = 1;
     MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
   }
 }

@@ -723,11 +723,9 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap
 Dimensions PrepareReshapeOutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                         const size_t incoming_op_index) {
   Dimensions s;
-
   auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape();
   auto input_shape = ops[incoming_op_index]->inputs_tensor_info()[0].shape();
   auto strategy = ops[incoming_op_index]->selected_strategy();
-
   std::vector<int64_t> mapping;
   int64_t tmp_prod = 1;
   int64_t tmp_index = 0;

@@ -740,67 +738,63 @@ Dimensions PrepareReshapeOutputStrategy(const std::vector<std::shared_ptr<Operat
   // e.g. input_shape [2,2,2,2] output_shape [2,4,2], the mapping is [0,2,3,-1].
   if (output_shape.size() >= input_shape.size()) {
     for (size_t i = 0; i < input_shape.size(); i++) {
-      if (input_shape[i] < output_shape[tmp_index]) {
+      if (input_shape[i] < output_shape[LongToSize(tmp_index)]) {
         break;
-      } else {
-        for (size_t j = tmp_index; j < output_shape.size(); j++) {
-          tmp_prod *= output_shape[j];
-          tmp_index++;
-          if (input_shape[i] == tmp_prod) {
-            tmp_prod = 1;
-            mapping.push_back(i);
-            break;
-          } else {
-            mapping.push_back(i);
-          }
-        }
-      }
+      }
+      for (size_t j = LongToSize(tmp_index); j < output_shape.size(); j++) {
+        tmp_prod *= output_shape[j];
+        tmp_index++;
+        if (input_shape[i] == tmp_prod) {
+          tmp_prod = 1;
+          mapping.push_back(i);
+          break;
+        }
+        mapping.push_back(i);
+      }
     }
     mapping.push_back(-1);
-  } else {
-    tmp_index = 0;
-    for (size_t i = 0; i < output_shape.size(); i++) {
-      if (output_shape[i] < input_shape[tmp_index]) {
-        break;
-      } else {
-        for (size_t j = tmp_index; j < input_shape.size(); j++) {
-          tmp_prod *= input_shape[j];
-          if (output_shape[i] == tmp_prod) {
-            tmp_prod = 1;
-            mapping.push_back(tmp_index);
-            tmp_index++;
-            break;
-          }
-          tmp_index++;
-        }
-      }
-    }
-    mapping.push_back(-1);
-  }
-  tmp_index = 0;
-  tmp_prod = 1;
-  if (output_shape.size() >= input_shape.size()) {
+    tmp_index = 0;
+    tmp_prod = 1;
     for (size_t i = 0; i < output_shape.size(); i++) {
-      if ((int64_t)mapping[i] == tmp_index) {
-        s.push_back(strategy->GetInputDim()[0][mapping[i]]);
+      if (mapping[i] == tmp_index) {
+        s.push_back(strategy->GetInputDim()[0][LongToSize(mapping[i])]);
         tmp_index++;
       } else {
         s.push_back(1);
       }
     }
-  } else {
-    for (size_t i = 0; i < output_shape.size(); i++) {
-      if (mapping[i] == -1) {
-        mapping.push_back(-1);
-        s.push_back(1);
-      } else {
-        for (size_t j = tmp_index; j < input_shape.size(); j++) {
-          tmp_prod *= strategy->GetInputDim()[0][j];
-          tmp_index++;
-          if (mapping[i] == (int64_t)j) {
-            s.push_back(tmp_prod);
-            tmp_prod = 1;
-            break;
-          }
-        }
-      }
-    }
+    return s;
   }
+
+  for (size_t i = 0; i < output_shape.size(); i++) {
+    if (output_shape[i] < input_shape[LongToSize(tmp_index)]) {
+      break;
+    }
+    for (size_t j = LongToSize(tmp_index); j < input_shape.size(); j++) {
+      tmp_prod *= input_shape[j];
+      if (output_shape[i] == tmp_prod) {
+        tmp_prod = 1;
+        mapping.push_back(tmp_index);
+        tmp_index++;
+        break;
+      }
+      tmp_index++;
+    }
+  }
+  mapping.push_back(-1);
+  tmp_index = 0;
+  tmp_prod = 1;
+  for (size_t i = 0; i < output_shape.size(); i++) {
+    if (mapping[i] == -1) {
+      mapping.push_back(-1);
+      s.push_back(1);
+    } else {
+      for (size_t j = tmp_index; j < input_shape.size(); j++) {
+        tmp_prod *= strategy->GetInputDim()[0][j];
+        tmp_index++;
+        if (mapping[i] == (int64_t)j) {
+          s.push_back(tmp_prod);
+          tmp_prod = 1;
+          break;
+        }
+      }
+    }
+  }

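The hunk above (reconstructed from a garbled rendering, so the exact marker placement is my best effort) converts the signed tmp_index through LongToSize before every subscript and flattens the nested else branches, with the output>=input branch now returning early. The underlying idea is a product-matching map between the finer and coarser of the two shapes. A standalone sketch of that grouping — note it uses my own mapping convention (each fine dim maps to the coarse dim it folds into), not the snippet's -1-terminated one:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Walk the finer shape and group consecutive dims whose product equals the
// next dim of the coarser shape.
std::vector<int64_t> MapFineToCoarse(const std::vector<int64_t> &fine, const std::vector<int64_t> &coarse) {
  std::vector<int64_t> mapping;
  size_t c = 0;
  int64_t prod = 1;
  for (size_t f = 0; f < fine.size() && c < coarse.size(); ++f) {
    prod *= fine[f];
    mapping.push_back(static_cast<int64_t>(c));  // fine dim f folds into coarse dim c
    if (prod == coarse[c]) {  // group complete: advance to the next coarse dim
      prod = 1;
      ++c;
    }
  }
  return mapping;
}

int main() {
  // input [2,2,2,2] reshaped to [2,4,2]: dims 0 -> 0, 1 and 2 -> 1, 3 -> 2.
  for (int64_t m : MapFineToCoarse({2, 2, 2, 2}, {2, 4, 2})) std::cout << m << ' ';
  std::cout << '\n';
  return 0;
}
```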
@@ -811,12 +805,11 @@ Dimensions PrepareReshapeOutputStrategy(const std::vector<std::shared_ptr<Operat
 Dimensions PrepareTransposeOutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                           const size_t incoming_op_index) {
   Dimensions s;
-
   auto permutation = GetValue<std::vector<int64_t>>(ops[incoming_op_index]->input_value().at(1));
   auto strategy = ops[incoming_op_index]->selected_strategy();
   // The strategies are assigned according to the order in permutation (user defined).
   for (size_t i = 0; i < permutation.size(); i++) {
-    s.push_back(strategy->GetInputDim()[0][permutation[i]]);
+    s.push_back(strategy->GetInputDim()[0][LongToSize(permutation[i])]);
   }
   return s;
 }

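The transpose case is far simpler: the output strategy is the input strategy reordered by the user-defined permutation, and the fix converts each signed permutation entry before it is used as an index. A hedged mini-version of the same rule:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: reorder a per-dimension strategy by a permutation vector.
std::vector<int64_t> PermuteStrategyDemo(const std::vector<int64_t> &in_strategy,
                                         const std::vector<int64_t> &permutation) {
  std::vector<int64_t> out;
  out.reserve(permutation.size());
  for (int64_t p : permutation) {
    out.push_back(in_strategy[static_cast<size_t>(p)]);  // LongToSize-style cast
  }
  return out;  // e.g. {4, 2, 1} under permutation {2, 0, 1} becomes {1, 4, 2}
}
```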
@@ -190,7 +190,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
   }
 }

-StrategyRec GetOneLoopStrategy(size_t op_inputs_num, StrategyRec old_str, StrategyRec new_str) {
+StrategyRec GetOneLoopStrategy(size_t op_inputs_num, const StrategyRec &old_str, StrategyRec new_str) {
   for (size_t i = 0; i < op_inputs_num; i++) {
     if (old_str.inputTensor[i].str_n != 0 && old_str.inputTensor[i].str_c != 0 && old_str.inputTensor[i].str_h != 0 &&
         old_str.inputTensor[i].str_w != 0) {

@@ -46,7 +46,7 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);

 Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);

-StrategyRec GetOneLoopStrategy(size_t op_inputs_num, StrategyRec old_str, StrategyRec new_str);
+StrategyRec GetOneLoopStrategy(size_t op_inputs_num, const StrategyRec &old_str, StrategyRec new_str);

 size_t GetDataTypeSize(const TensorType &type);
 } // namespace parallel

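These two hunks (implementation and header) are a pure signature fix: old_str is only ever read, so a by-value StrategyRec copy on every call is wasted work, while new_str is modified and returned and therefore stays by value. Illustrative only — Tensor4D/StrategyRecDemo below are made-up stand-ins for the real StrategyRec:

```cpp
#include <cstddef>
#include <vector>

struct TensorStrDemo { float str_n, str_c, str_h, str_w; };
struct StrategyRecDemo { std::vector<TensorStrDemo> inputTensor; };

// old_str: read-only, passed by const reference (no copy).
// new_str: mutated and returned, kept by value; NRVO/move keeps the return cheap.
StrategyRecDemo ScaleDemo(const StrategyRecDemo &old_str, StrategyRecDemo new_str) {
  for (size_t i = 0; i < old_str.inputTensor.size() && i < new_str.inputTensor.size(); ++i) {
    new_str.inputTensor[i].str_w *= old_str.inputTensor[i].str_w;
  }
  return new_str;
}
```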
@@ -48,7 +48,7 @@ bool CheckDeviceConfig(int64_t device_num, int64_t global_rank, const std::strin
     return false;
   }
   // 'device_num_converted' must be divisible by 8
-  if (device_num % DEVICE_NUM_PER_SERVER != 0 && device_num != 1 && device_num != 2 && device_num != 4) {
+  if (LongToSize(device_num) % DEVICE_NUM_PER_SERVER != 0 && device_num != 1 && device_num != 2 && device_num != 4) {
     MS_LOG(ERROR) << "The context configuration parameter device_num' must be divisible by 8, "
                      "or equal to 1, 2 or 4, but got the value of device_num: "
                   << device_num;

@@ -275,7 +275,7 @@ std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_na
 RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
   std::string rank_list_name = FindRankListNameByHashName(hash_name);
   if (rank_list_name == "WORLD_GROUP") {
-    int64_t device_num = g_device_manager->DeviceNum();
+    int64_t device_num = SizeToLong(g_device_manager->DeviceNum());
     RankList rank_list;
     for (size_t i = 0; i < size_t(device_num); ++i) {
       rank_list.push_back(i);

@@ -67,7 +67,7 @@ void GraphSplitter::DyeGraph() {
   MS_EXCEPTION_IF_NULL(func_graph_);

   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
-  std::for_each(all_nodes.begin(), all_nodes.end(), [this](AnfNodePtr &node) {
+  (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](AnfNodePtr &node) {
     MS_EXCEPTION_IF_NULL(node);
     // Mark all nodes with original label at the beginning.
     node_labels_[node] = default_label_;

@@ -98,17 +98,17 @@ std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
     auto cnode_split_label = node_labels_[n];
     // If this node's label is not the same as last node's, create a segment from 'segment_nodes'.
     if (cnode_split_label != last_label && !segment.nodes.empty()) {
-      results.emplace_back(segment);
+      (void)results.emplace_back(segment);
       segment.nodes.clear();
     }
     // Mark the last label.
     last_label = cnode_split_label;
     segment.label = cnode_split_label;
-    segment.nodes.emplace_back(n);
+    (void)segment.nodes.emplace_back(n);
   }

   // Add the last segment.
-  results.emplace_back(segment);
+  (void)results.emplace_back(segment);
   MS_LOG(INFO) << "Segments number with different distributed split labels is " << results.size();
   return results;
 }

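Many hunks in this request only prepend (void) to calls such as emplace_back, insert, for_each, and Replace. The cast changes no behavior; it documents that the return value is dropped on purpose, which satisfies checkers that flag ignored returns. A tiny self-contained illustration:

```cpp
#include <algorithm>
#include <vector>

void VoidCastDemo() {
  std::vector<int> results;
  (void)results.emplace_back(42);  // since C++17 emplace_back returns a reference
  (void)std::for_each(results.begin(), results.end(), [](int &v) { v *= 2; });  // returns the functor
}
```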
@@ -159,15 +159,15 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
   std::vector<AnfNodePtr> concerned_out_degree_nodes = FindInterProcessOutDegree(nodes, comm_edges);

   std::vector<AnfNodePtr> make_tuple_send_input = {NewValueNode(prim::kPrimMakeTuple)};
-  make_tuple_send_input.insert(make_tuple_send_input.end(), concerned_in_degree_nodes.begin(),
-                               concerned_in_degree_nodes.end());
+  (void)make_tuple_send_input.insert(make_tuple_send_input.end(), concerned_in_degree_nodes.begin(),
+                                     concerned_in_degree_nodes.end());
   auto make_tuple = func_graph_->NewCNode(make_tuple_send_input);
   if (concerned_out_degree_nodes.empty()) {
     std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend)};
     out.push_back(make_tuple_send_input.back());
     out.push_back(make_tuple);
     auto out_node = func_graph_->NewCNode(out);
-    func_graph_->manager()->Replace(func_graph_->output(), out_node);
+    (void)func_graph_->manager()->Replace(func_graph_->output(), out_node);
   } else {
     for (auto &recv : concerned_out_degree_nodes) {
       std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), recv->cast<CNodePtr>()->inputs()[1],

@@ -262,7 +262,7 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(cons

     InterProcessOpEdge comm_edge = {input_i, cnode};
     auto comm_node_pair = std::make_tuple(send_node, recv_node, cnode, SizeToInt(i));
-    comm_edges.insert(std::make_pair(comm_edge, comm_node_pair));
+    (void)comm_edges.insert(std::make_pair(comm_edge, comm_node_pair));
   }
   return comm_edges;
 }

@@ -292,7 +292,7 @@ InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeOutputs(con

     InterProcessOpEdge comm_edge = {cnode, user_node};
     auto comm_node_pair = std::make_tuple(send_node, recv_node, user_node, index);
-    comm_edges.insert(std::make_pair(comm_edge, comm_node_pair));
+    (void)comm_edges.insert(std::make_pair(comm_edge, comm_node_pair));
   }
   return comm_edges;
 }

@@ -369,7 +369,7 @@ std::vector<AnfNodePtr> GraphSplitter::FindInterProcessInDegree(const std::vecto
       MS_LOG(INFO) << input_i->fullname_with_scope() << " to " << cnode->fullname_with_scope()
                    << " is a communication edge.";
       auto comm_node_pair = comm_edges.at({input_i, cnode});
-      results.emplace_back(std::get<0>(comm_node_pair));
+      (void)results.emplace_back(std::get<0>(comm_node_pair));
     }
   }
   return results;

@@ -396,7 +396,7 @@ std::vector<AnfNodePtr> GraphSplitter::FindInterProcessOutDegree(const std::vect
       MS_LOG(INFO) << cnode->fullname_with_scope() << " to " << user_node->fullname_with_scope()
                    << " is a communication edge.";
       auto comm_node_pair = comm_edges.at({cnode, user_node});
-      results.emplace_back(std::get<1>(comm_node_pair));
+      (void)results.emplace_back(std::get<1>(comm_node_pair));
     }
   }
   return results;

@@ -108,8 +108,7 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) {
   cnode->AddPrimalAttr(IN_STRATEGY, strategy);
 }

-CNodePtr FindNodeWithMircoSize(const AnfNodePtr &node_user, const FuncGraphManagerPtr &manager,
-                               const NodeUsersMap &node_users_map) {
+CNodePtr FindNodeWithMircoSize(const AnfNodePtr &node_user, const NodeUsersMap &node_users_map) {
   // Recursively find micro tags, this may takes much more time if layers are too much
   std::queue<AnfNodePtr> visited;
   visited.push(node_user);

@@ -158,14 +157,14 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
   ValuePtr micro = nullptr;
   int64_t step = 0;
   if (grad_accumulation_shard) {
-    auto cnode_with_micro_size = FindNodeWithMircoSize(cnode, manager, node_user_map);
+    auto cnode_with_micro_size = FindNodeWithMircoSize(cnode, node_user_map);
     if (cnode_with_micro_size && cnode_with_micro_size->HasPrimalAttr(MICRO)) {
       micro = cnode_with_micro_size->GetPrimalAttr(MICRO);
       step = GetValue<int64_t>(micro);
     }
   }
   args1 = MakeValue(param_ptr->user_data<TensorLayout>()->opt_shard_group());
-  args2 = MakeValue(param_ptr->param_info()->comm_fusion() + step * PIPELINE_FUSTION_OFFSET);
+  args2 = MakeValue(LongToSize(param_ptr->param_info()->comm_fusion()) + LongToSize(step) * PIPELINE_FUSTION_OFFSET);
   OperatorAttrs attrs = {};
   auto py_instance = CreateOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
   auto value_node = NewValueNode(py_instance);

@@ -175,7 +174,7 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
   auto attrs_prim = new_prim->attrs();
   attrs_prim[GROUP] = args1;
   attrs_prim[kAttrFusion] = args2;
-  new_prim->SetAttrs(attrs_prim);
+  (void)new_prim->SetAttrs(attrs_prim);

   std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
   auto graph = cnode->func_graph();

@@ -252,11 +251,11 @@ void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
 // If the graph likes the followings:
 // 1. MicroStepAllGather->MirrorMicro->load, we need to visit the param after the load
 std::vector<std::pair<AnfNodePtr, int>> FindNextNode(const std::pair<AnfNodePtr, int> &node_ptr,
-                                                     const FuncGraphPtr &root, const NodeUsersMap &node_users_map) {
+                                                     const NodeUsersMap &node_users_map) {
   std::vector<std::pair<AnfNodePtr, int>> to_be_visited_set;
   if (!IsPrimitiveCNode(node_ptr.first, prim::kPrimMirrorMicroStep) &&
       !IsPrimitiveCNode(node_ptr.first, prim::kPrimMicroStepAllGather)) {
-    to_be_visited_set.emplace_back(node_ptr);
+    (void)to_be_visited_set.emplace_back(node_ptr);
     return to_be_visited_set;
   }
   auto node_set = node_users_map.at(node_ptr.first);

@@ -269,7 +268,7 @@ std::vector<std::pair<AnfNodePtr, int>> FindNextNode(const std::pair<AnfNodePtr,
     visited.pop();
     if (!IsPrimitiveCNode(node.first, prim::kPrimMirrorMicroStep) &&
         !IsPrimitiveCNode(node.first, prim::kPrimMicroStepAllGather)) {
-      to_be_visited_set.emplace_back(node);
+      (void)to_be_visited_set.emplace_back(node);
     } else {
       auto next_node_set = node_users_map.at(node.first);
       for (auto &node_user : next_node_set) {

@@ -296,7 +295,7 @@ void AddVirtualAssignAdd(const FuncGraphPtr &root) {
     if (IsPrimitiveCNode(temp_node.first, prim::kPrimCast)) {
       temp_node = *node_users_map[temp_node.first].begin();
     }
-    auto node_set = FindNextNode(temp_node, root, node_users_map);
+    auto node_set = FindNextNode(temp_node, node_users_map);
     for (auto &node_user : node_set) {
       InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
     }

@@ -245,12 +245,12 @@ Status CumSumInfo::CheckStrategy(const StrategyPtr &strategy) {

   Strategys stra = strategy->GetInputDim();
   Dimensions input_strategy = stra.at(0);
-  if (input_strategy.size() <= IntToSize(axis_)) {
+  if (input_strategy.size() <= LongToSize(axis_)) {
     MS_LOG(ERROR) << "The " << name_ << " input strategy length: " << input_strategy.size() << ", is less ot equal to "
                   << axis_;
     return FAILED;
   }
-  auto axis_split = input_strategy[axis_];
+  auto axis_split = input_strategy[LongToSize(axis_)];
   if (axis_split > 1) {
     MS_LOG(ERROR) << "Currently, CumSum does not support the sharding strategies which splits axis.";
     return FAILED;

@@ -261,11 +261,11 @@ Status CumSumInfo::CheckStrategy(const StrategyPtr &strategy) {

 std::vector<StrategyPtr> CumSumInfo::GenerateOpStrategies(int64_t stage_id) {
   Shape input0_split(inputs_shape_[0].size(), 1);
-  if (axis_ < 0 || IntToSize(axis_) >= inputs_shape_[0].size()) {
+  if (axis_ < 0 || LongToSize(axis_) >= inputs_shape_[0].size()) {
     MS_LOG(EXCEPTION) << "Wrong axis value: " << axis_;
   }
   // Currently, CumSum does not support the sharding strategies which splits axis.
-  input0_split[axis_] = 0;
+  input0_split[LongToSize(axis_)] = 0;
   Shapes splittable_inputs = {input0_split};

   std::vector<StrategyPtr> sp_vector;

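Both CumSum hunks revolve around the same point: axis_ is a signed int64_t, so the code must range-check it before converting it to an unsigned index (and IntToSize was the wrong-width helper for an int64_t). A hedged sketch of that validate-then-convert pattern — unlike CumSumInfo, which simply rejects negative axes, this hypothetical helper also wraps them NumPy-style:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

inline size_t NormalizeAxisDemo(int64_t axis, size_t rank) {
  if (axis < 0) axis += static_cast<int64_t>(rank);  // tolerate negative axes
  if (axis < 0 || axis >= static_cast<int64_t>(rank)) {
    throw std::out_of_range("axis out of range");
  }
  return static_cast<size_t>(axis);  // safe: 0 <= axis < rank
}
```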
@@ -406,7 +406,7 @@ Status GatherInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
     return FAILED;
   }

-  if (param_strategy[axis_] == 1) {
+  if (param_strategy[LongToSize(axis_)] == 1) {
     MS_LOG(ERROR) << name_ << ": The axis is not split, can not set output strategy";
     return FAILED;
   }

@@ -495,7 +495,7 @@ Status GatherInfo::InferDevMatrixShape() {

   // param_strategy(axis) is 1
   if (param_strategy.at(LongToSize(axis_)) == 1) {
-    dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
+    (void)dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
   }

   // infer out dev_matrix_shape

@@ -536,7 +536,7 @@ void GatherInfo::InferInputsTensorMap() {
   Shape tensor_map_params;
   auto param_strategy = strategy_->GetInputDim().at(0);
   if (param_strategy.at(LongToSize(axis_)) != 1) {
-    tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE);
+    (void)tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE);
     for (size_t i = 0; i < param_size; ++i) {
       tensor_map_params.push_back(SizeToLong(param_size - i - 1));
     }

@@ -549,8 +549,8 @@ void GatherInfo::InferInputsTensorMap() {
       tensor_map_index.push_back(SizeToLong(index_size - i - 1));
     }
   }
-  inputs_tensor_map_.emplace_back(std::move(tensor_map_params));
-  inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
+  (void)inputs_tensor_map_.emplace_back(std::move(tensor_map_params));
+  (void)inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
 }

 Shape GatherInfo::InferOutputsTensorMapSplitAxis() {

@@ -560,18 +560,18 @@ Shape GatherInfo::InferOutputsTensorMapSplitAxis() {
   if (axis_ == 0) {
     if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
       // the output is repeat calculation
-      tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
+      (void)tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
     } else {
-      tensor_map_out.insert(tensor_map_out.end(), param_size - 1);
+      (void)tensor_map_out.insert(tensor_map_out.end(), param_size - 1);
     }
-    tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
+    (void)tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
     for (size_t i = 1; i < param_size; ++i) {
       tensor_map_out.push_back(param_size - 1 - i);
     }
   } else {
     for (size_t i = 0; i < param_size; ++i) {
       if (i == LongToSize(axis_)) {
-        tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
+        (void)tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
       } else {
         if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
           tensor_map_out.push_back(MAP_NONE);

@@ -1390,8 +1390,8 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
   return SUCCESS;
 }

-TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index) {
-  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) { return swc->strategy_ptr->IsEqual(stra); };
+TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t input_index) {
+  auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) { return swc->strategy_ptr->IsEqual(stra); };
   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
   if (it != strategy_cost_.end()) {
     const auto &input_info = (*it)->inputs_ptr[input_index];

@@ -1401,8 +1401,8 @@ TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(StrategyPtr stra, siz
   return empty;
 }

-TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index) {
-  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) { return swc->strategy_ptr->IsEqual(stra); };
+TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t output_index) {
+  auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) { return swc->strategy_ptr->IsEqual(stra); };
   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
   if (it != strategy_cost_.end()) {
     const auto &output_info = (*it)->outputs_ptr[output_index];

@@ -1412,8 +1412,8 @@ TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, si
   return empty;
 }

-StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index) {
-  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) {
+StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(const TensorLayout &input_layout, size_t input_index) {
+  auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) {
     return swc->inputs_ptr[input_index].tensor_layout() == input_layout;
   };
   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);

@@ -1423,8 +1423,8 @@ StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(TensorLayout input_lay
   return nullptr;
 }

-StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index) {
-  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) {
+StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(const TensorLayout &output_layout, size_t output_index) {
+  auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) {
     return swc->outputs_ptr[output_index].tensor_layout() == output_layout;
   };
   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);

@@ -1434,14 +1434,14 @@ StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(TensorLayout output_l
   return nullptr;
 }

-bool OperatorInfo::IsReshape() {
+bool OperatorInfo::IsReshape() const {
   if (name_.find(RESHAPEINFO) != std::string::npos) {
     return true;
   }
   return false;
 }

-bool OperatorInfo::IsTmpIdentity() {
+bool OperatorInfo::IsTmpIdentity() const {
   if (name_.find(IDENTITY_INFO) != std::string::npos) {
     return true;
   }

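IsReshape/IsTmpIdentity are pure queries, so const-qualifying them both documents that they never mutate the object and lets them be called through const OperatorInfo references. A reduced, hypothetical illustration of the same idea:

```cpp
#include <string>
#include <utility>

class OperatorInfoDemo {
 public:
  explicit OperatorInfoDemo(std::string name) : name_(std::move(name)) {}
  // const: callable on const instances, guaranteed not to modify state.
  bool IsReshape() const { return name_.find("ReshapeInfo") != std::string::npos; }

 private:
  std::string name_;
};
```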
@@ -147,12 +147,12 @@ class OperatorInfo {
   StrategyPtr selected_strategy() const { return selected_strategy_; }
   CostPtr selected_cost() const { return selected_cost_; }

-  TensorLayout GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index);
-  TensorLayout GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index);
-  StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index);
-  StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index);
-  bool IsReshape();
-  bool IsTmpIdentity();
+  TensorLayout GetInputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t input_index);
+  TensorLayout GetOutputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t output_index);
+  StrategyPtr GetStrategyFromSWCByInputLayout(const TensorLayout &input_layout, size_t input_index);
+  StrategyPtr GetStrategyFromSWCByOutputLayout(const TensorLayout &output_layout, size_t output_index);
+  bool IsReshape() const;
+  bool IsTmpIdentity() const;

   void set_swc_index(int64_t, int64_t);
   int64_t swc_index() { return swc_index_; }

@@ -150,7 +150,8 @@ Status ReshapeInfo::ComputeReplaceOp() {
   int64_t shape_dim = 2;
   auto value = replace_op_.front().second.second.front().first.second;
   Shape dst_shape = GetValue<std::vector<int64_t>>(value);
-  Shape origin_dst_shape = GetValue<std::vector<int64_t>>(cnode_->input(shape_dim)->cast<ValueNodePtr>()->value());
+  Shape origin_dst_shape =
+      GetValue<std::vector<int64_t>>(cnode_->input(LongToSize(shape_dim))->cast<ValueNodePtr>()->value());
   if (dst_shape.size() == origin_dst_shape.size()) {
     for (size_t i = 0; i < dst_shape.size(); ++i) {
       if (origin_dst_shape[i] != dst_shape[i] && origin_dst_shape[i] != -1) {

@@ -602,12 +603,12 @@ int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &in
   return index_comm[0].first;
 }

-bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) {
+bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) const {
   if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
     MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
     return false;
   }
-  const auto &swc = strategy_cost_[swc_index];
+  const auto &swc = strategy_cost_[LongToSize(swc_index)];
   if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
     return true;
   }

@@ -617,12 +618,12 @@ bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, cons
   return false;
 }

-bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) {
+bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) const {
   if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
     MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
     return false;
   }
-  const auto &swc = strategy_cost_[swc_index];
+  const auto &swc = strategy_cost_[LongToSize(swc_index)];
   if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
     return true;
   }

@@ -632,19 +633,19 @@ bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const
   return false;
 }

-TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) {
+TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) const {
   if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
     MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
   }
-  const auto &swc = strategy_cost_[swc_index];
+  const auto &swc = strategy_cost_[LongToSize(swc_index)];
   return std::move(swc->inputs_ptr[0].tensor_layout());
 }

-TensorLayout ReshapeInfo::GetOutputLayoutBySWCIndex(int64_t swc_index) {
+TensorLayout ReshapeInfo::GetOutputLayoutBySWCIndex(int64_t swc_index) const {
   if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
     MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
   }
-  const auto &swc = strategy_cost_[swc_index];
+  const auto &swc = strategy_cost_[LongToSize(swc_index)];
   return std::move(swc->outputs_ptr[0].tensor_layout());
 }
 } // namespace parallel

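The ReshapeInfo hunks repeat the same guard-then-convert idiom: the signed swc_index is range-checked against strategy_cost_.size() first, then converted with LongToSize exactly once at the subscript. A hedged generic sketch of that pattern (AtSigned is hypothetical, not MindSpore API):

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

template <typename T>
const T &AtSigned(const std::vector<T> &v, int64_t idx) {
  if (idx < 0 || idx >= static_cast<int64_t>(v.size())) {
    throw std::out_of_range("index out of range");
  }
  return v[static_cast<size_t>(idx)];  // single explicit conversion at the access
}
```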
@@ -73,11 +73,11 @@ class ReshapeInfo : public OperatorInfo {
   int64_t GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &);
   int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &);
   int64_t GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &);
-  bool CheckStrategyConsistencyByOutputLayout(int64_t, const TensorLayout &);
-  bool CheckStrategyConsistencyByInputLayout(int64_t, const TensorLayout &);
+  bool CheckStrategyConsistencyByOutputLayout(int64_t, const TensorLayout &) const;
+  bool CheckStrategyConsistencyByInputLayout(int64_t, const TensorLayout &) const;

-  TensorLayout GetInputLayoutBySWCIndex(int64_t);
-  TensorLayout GetOutputLayoutBySWCIndex(int64_t);
+  TensorLayout GetInputLayoutBySWCIndex(int64_t) const;
+  TensorLayout GetOutputLayoutBySWCIndex(int64_t) const;

  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;

@@ -850,7 +850,7 @@ void PipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfN
       if (!send_depend) {
         continue;
       }
-      send_ops->insert(send_ops->begin(), send_depend);
+      (void)send_ops->insert(send_ops->begin(), send_depend);
       continue;
     }
     if (Reuse(node, user_node_stage, *send_ops, DEST_RANK)) {

@@ -302,7 +302,7 @@ void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
 void AddOperatorToIgnoreCandidates(const PrimitivePtr &prim, const OperatorInfoPtr &operator_info) {
   if (prim->name() == CAST) {
     // add CAST into ignore_candidate
-    ignore_candidate_.insert(operator_info);
+    (void)ignore_candidate_.insert(operator_info);
   }
 }

@@ -374,7 +374,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
   // BatchParallelInfo operator
   operator_info->ComputeBatchSplitFlagList();
-  bool retGenStra;
+  Status retGenStra;
   if (AttrFound(attrs, STRATEGY_GEN_MODE) && GetValue<std::string>(attrs[STRATEGY_GEN_MODE]) == DATA_PARALLEL) {
     MS_LOG(INFO) << "generating batch parallel strategy...";
     StrategyPtr strategyPtr = parallel::GenerateBatchParallelStrategy(operator_info, prim);

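The bool-to-Status change is more than cosmetic: if Status is a plain enum whose SUCCESS value is 0 (as is conventional), storing it into a bool collapses every result to true/false and turns success into false. Keeping the enum type both silences the implicit-conversion warning and preserves the actual result. A sketch with assumed definitions (MindSpore's real Status is declared in its own header):

```cpp
enum StatusDemo { SUCCESS_DEMO = 0, FAILED_DEMO = 1 };

StatusDemo GenerateStrategyDemo() { return SUCCESS_DEMO; }

bool CallerDemo() {
  StatusDemo ret_gen_stra = GenerateStrategyDemo();  // no silent enum-to-bool narrowing
  return ret_gen_stra == SUCCESS_DEMO;               // compare explicitly instead
}
```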
@@ -1076,7 +1076,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
     PrintStrategy(s_strategy);
   }
   // Remove some operatorInfo from the CNODEs
-  IgnoreOperatorsInCostGraph();
+  (void)IgnoreOperatorsInCostGraph();

   ops_in_a_loop_.clear();
   configured_stra_ops_.clear();

@@ -215,7 +215,7 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
     if (next_node_dtype) {
       MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving"
                    << " communication.";
-      pre_node_ = CreateFP16Cast(node, pre_node, node_user_map, next_node_dtype);
+      pre_node_ = CreateFP16Cast(node, pre_node, next_node_dtype);
     }
     node_input = CreateMirrorInput(root, op, pre_node_, instance_name, param_name);
   } else {

@@ -746,10 +746,10 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
   if (reshape_type_str.find(BOOL) != std::string::npos) {
     auto cast_int = CreateCastOp(kInt32);
     auto cast_bool = CreateCastOp(kBool);
-    replace_op.insert(replace_op.begin(), cast_int);
-    replace_op.insert(replace_op.end(), cast_bool);
-    replace_op_info.insert(replace_op_info.begin(), {false, 1});
-    replace_op_info.insert(replace_op_info.end(), {false, 1});
+    (void)replace_op.insert(replace_op.begin(), cast_int);
+    (void)replace_op.insert(replace_op.end(), cast_bool);
+    (void)replace_op_info.insert(replace_op_info.begin(), {false, 1});
+    (void)replace_op_info.insert(replace_op_info.end(), {false, 1});
   }

   // step2:traverse op_list and insert node

@@ -92,7 +92,6 @@ AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr
     if (!user_node->has_user_data<OperatorInfo>()) {
       continue;
     }
-    auto op_info = user_node->user_data<OperatorInfo>();
     auto tensor_info = GetInputsTensorInfo(node_user);
     if (is_first_tensor_info) {
       is_first_tensor_info = false;

@@ -111,7 +110,8 @@ AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr
 }

 bool IsInNodeList(const CNodePtr &cnode, const std::set<string> &check_list) {
-  return std::any_of(check_list.begin(), check_list.end(), [cnode](string in) { return IsSomePrimitive(cnode, in); });
+  return std::any_of(check_list.begin(), check_list.end(),
+                     [cnode](const string &in) { return IsSomePrimitive(cnode, in); });
 }

 bool IsParallelCareNode(const CNodePtr &cnode) {

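The IsInNodeList tweak changes the lambda parameter from a by-value string to a const reference, avoiding one std::string copy for every element std::any_of visits. The same shape, reduced to plain strings:

```cpp
#include <algorithm>
#include <set>
#include <string>

// Illustrative only: const std::string & in the lambda avoids per-element copies.
bool ContainsNameDemo(const std::set<std::string> &check_list, const std::string &name) {
  return std::any_of(check_list.begin(), check_list.end(),
                     [&name](const std::string &in) { return in == name; });
}
```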
@@ -254,8 +254,8 @@ RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
       is_first_group = false;
     } else {
       std::vector<int64_t> new_comm_group_list;
-      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
-                            std::back_inserter(new_comm_group_list));
+      (void)std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(),
+                                  group_list.end(), std::back_inserter(new_comm_group_list));
       common_group_list = new_comm_group_list;
     }
   }

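A minimal model of the call above: std::set_intersection requires both input ranges to be sorted, writes the common elements through the inserter, and returns an output iterator, which the (void) marks as intentionally unused:

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

std::vector<int64_t> IntersectDemo(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  std::vector<int64_t> out;
  // Precondition: a and b are sorted ascending.
  (void)std::set_intersection(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(out));
  return out;  // e.g. {0, 1, 2, 3} and {2, 3, 5} yield {2, 3}
}
```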
@@ -456,8 +456,7 @@ TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMa
 // Create a cast node given the current node and the previous node. The target type of the the cast is from the
 // compute_node_type.
 // Return the new cast node with pre_node as the inputs.
-AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
-                          const TypePtr &compute_node_type) {
+AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type) {
   const char kOpsFunctionModelName[] = "mindspore.ops.functional";
   static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
   const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);

@@ -525,7 +524,7 @@ void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) {
     if (cast_input->isa<Parameter>()) {
       MS_LOG(INFO) << "Cast for parameter no needs recompute to avoid redundant trans_data operator";
       PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
-      prim->AddAttr("recompute", MakeValue(false));
+      (void)prim->AddAttr("recompute", MakeValue(false));
     }
   }
 }

@@ -43,8 +43,7 @@ AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr
 RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
 void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
-AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
-                          const TypePtr &compute_node_type);
+AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type);
 AnfNodePtr GetChildCastNode(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
 TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
 void LabelGenMaskMicro(const FuncGraphPtr &root);

@@ -215,7 +215,7 @@ Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map, con
   }
   straspb::ParallelGroupRanks *ckpt_restore_rank_list = parallel_group_map.mutable_ckpt_restore_rank_list();
   for (auto &restore_rank : restore_rank_list) {
-    ckpt_restore_rank_list->add_dim(restore_rank);
+    ckpt_restore_rank_list->add_dim(LongToSize(restore_rank));
   }
   if (!CheckPath(group_info_save_file_)) {
     MS_LOG(EXCEPTION) << "CheckPoint file in invalid";

@@ -26,6 +26,7 @@ _MAX_GROUP_NAME_LEN = 127
 _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
 _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"

+
 class _ParallelFusionConfig:
     """
     The key of the Parallel fusion method configuration.

@@ -37,6 +38,7 @@ class _ParallelFusionConfig:
     INDEX = "index"
     SIZE = "size"

+
 class _ParallelOptimizerConfig:
     """
     The key of the Parallel Optimizer. There are three

@@ -697,8 +699,8 @@ class _AutoParallelContext:
         self._context_handle.set_grad_accumulation_shard(
             parallel_optimizer_config[grad_shard_name])


     def get_grad_accumulation_shard(self):
         """Get grad accumulation shard."""
         self.check_context_handle()
         return self._context_handle.get_grad_accumulation_shard()

@@ -49,6 +49,7 @@ class _MpiConfig:

     @property
     def enable_mpi(self):
+        """Get enable mpi."""
         return self._mpiconfig_handle.get_enable_mpi()

     @enable_mpi.setter