!11472 support to checkpoint group info

From: @yangzhenzhang
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-01-21 17:11:23 +08:00 committed by Gitee
commit 0641940b87
16 changed files with 184 additions and 27 deletions

View File

@ -124,6 +124,10 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck
strategy_ckpt_save_file_ = strategy_ckpt_save_file;
}
void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
group_ckpt_save_file_ = group_ckpt_save_file;
}
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
all_reduce_fusion_split_indices_[group] = indices;
}

View File

@ -102,6 +102,8 @@ class ParallelContext {
std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }
void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
enable_parallel_optimizer_ = enable_parallel_optimizer;
@ -132,6 +134,7 @@ class ParallelContext {
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_;
std::string group_ckpt_save_file_;
bool enable_parallel_optimizer_;
};

View File

@ -83,6 +83,7 @@ class DeviceManager {
void Clear();
std::string world_group() const { return gm_.world_group(); }
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); }
std::string FindRankListNameByHashName(const std::string &hash_name);
private:

View File

@ -40,16 +40,16 @@ class DynCreator {
public:
~DynCreator() = default;
// creat static singleton dyn_creator instance
// create static singleton dyn_creator instance
static DynCreator &Instance() {
static DynCreator fac = DynCreator();
return fac;
}
// register
void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
// creator
OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
const PrimitiveAttrs &attrs, size_t count) {
OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
const PrimitiveAttrs &attrs, size_t count) {
std::string op_name = name + std::to_string(count);
auto iter = Function_map_.find(name);
if (iter == Function_map_.end()) {
@ -67,7 +67,7 @@ class DynCreator {
class RegisterAction {
public:
RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
DynCreator::Instance().Regist(name, creatfn);
DynCreator::Instance().Register(name, creatfn);
}
~RegisterAction() = default;

View File

@ -17,6 +17,7 @@
#include "frontend/parallel/group_manager.h"
#include <algorithm>
#include <vector>
#include <utility>
#include "backend/session/executor_manager.h"
#include "frontend/parallel/device_manager.h"
#include "utils/comm_manager.h"
@ -109,6 +110,9 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
return Status::FAILED;
}
std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
group_info_.push_back(group_info);
MS_LOG(INFO) << "Create group success, group name is " << group_name;
return Status::SUCCESS;
}
@ -187,5 +191,27 @@ Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Gro
}
void GroupManager::Clear() { (void)DestroyAllGroups(); }
Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
// Create group through the executor
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
MS_EXCEPTION_IF_NULL(executor);
for (auto &group : group_info) {
bool ret = executor->CreateCommGroup(group.first, group.second);
if (!ret) {
MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
return FAILED;
}
MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <map>
#include <string>
#include <vector>
#include <utility>
#include "frontend/parallel/device.h"
#include "frontend/parallel/status.h"
@ -62,6 +63,7 @@ class GroupManager {
Status FindGroup(const std::string &name, Group **group);
std::string world_group() const { return world_group_; }
void set_world_group(const std::string &name) { world_group_ = name; }
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return group_info_; }
void Clear();
private:
@ -69,7 +71,10 @@ class GroupManager {
// the key is group name (name_)
std::map<std::string, Group> groups_;
std::string world_group_;
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info_;
};
Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info);
} // namespace parallel
} // namespace mindspore

View File

@ -160,7 +160,7 @@ Status ReduceMethod::InferForwardCommunication() {
Shape group_creat_map;
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
// it need to handle the first dimention of map.
// it need to handle the first dimension of map.
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
}
@ -200,12 +200,12 @@ Status ReduceMethod::InferForwardCommunication() {
}
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
// Creat AllReduceSum op
// Create AllReduceSum op
Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
std::string group_name = forward_group[0].name();
MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
// Creat RealDiv op
// Create RealDiv op
OperatorName operator1_name = REAL_DIV;
std::vector<Device> device_list = forward_group[0].GetDevicesList();
auto divisor = static_cast<float>(device_list.size());
@ -237,7 +237,7 @@ Status ReduceMeanInfo::InferForwardCommunication() {
Shape group_creat_map;
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
// it need to handle the first dimention of map.
// it need to handle the first dimension of map.
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
}

View File

@ -326,7 +326,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
std::string instance_name_base = FORWARD_OP;
std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to creat anfnode
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to create anfnode
MS_EXCEPTION_IF_NULL(forward_node);
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
@ -371,10 +371,10 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
if (pos >= SizeToLong(node->inputs().size())) {
MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size";
}
// Creat new node
// Create new node
AnfNodePtr target_node = node->input(LongToSize(pos));
MS_EXCEPTION_IF_NULL(target_node);
// Creat instance_name
// Create instance_name
auto op = (redistribution_oplist_ptr->first)[index];
std::string op_name = (redistribution_oplist_ptr->first)[index].first;
std::string instance_name_base = REDISTRIBUTION_OP;
@ -400,7 +400,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func
MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is "
<< instance_name;
}
// Creat new node
// Create new node
AnfNodePtr pre_node = node->input(LongToSize(pos));
MS_EXCEPTION_IF_NULL(pre_node);
InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
@ -595,7 +595,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
CNodePtr insert_node_new;
if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
return;
}
if (IsValueNode<Primitive>(node->input(0))) {
@ -883,10 +883,10 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
}
// Sovle the input order
// Solve the input order
// For example input_node:{segment_sum:1, segment_sum:2, gahter:2}
// The Original code here will bind the all operations to the first inputs of theses operatos
// However, the segment_sum operation needs two inputs, To sovle this
// The Original code here will bind the all operations to the first inputs of these operatos
// However, the segment_sum operation needs two inputs, To solve this
// We maintain a dict to count the times of the same operations,
// and bind the inputs according to the times of the op appears.
static std::unordered_map<AnfNodePtr, int> input_map = {};
@ -1241,9 +1241,9 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
}
}
OperatorInfoPtr operator_ =
(OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
(OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
if (operator_ == nullptr) {
MS_LOG(INFO) << "Creat " << name << " failed";
MS_LOG(INFO) << "Create " << name << " failed";
return nullptr;
}
std::string origin_name = operator_->name();
@ -1261,7 +1261,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
if (IsInBatchParallelBlackList(prim)) {
MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
}
MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
MS_EXCEPTION_IF_NULL(operator_);
}
@ -1351,7 +1351,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
}
if (cnode->input(0)->isa<CNode>()) {
if (cnode->inputs().size() < 2) {
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2";
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
}
base_shape_ptr = cnode->input(1)->Shape();
}
@ -2546,7 +2546,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
bool has_backward = !sens_loss_pairs.empty();
// split sens must before inserting the operators.
for (auto &pair : sens_loss_pairs) {
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
// If the type of sens node is not Tensor, it is unsupported now, do nothing default.
if (IsLastStage()) {
StepSplitSens(pair);
@ -2703,7 +2703,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
auto param_split_shapes = gatherv2_info->param_split_shapes();
auto index_offsets = gatherv2_info->index_offsets();
if (param_split_shapes.size() != index_offsets.size()) {
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same.";
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
}
std::vector<std::pair<int64_t, int64_t>> manual_shape;
for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
@ -2713,6 +2713,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
}
}
}
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
}
@ -3142,6 +3143,19 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
}
}
bool CreateGroupsByCkptFile(const std::string &file) {
GroupInfoMap group_info_map;
if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
return false;
}
if (CreateGroups(group_info_map) != SUCCESS) {
return false;
}
MS_LOG(INFO) << "Create groups by checkpoint file success";
return true;
}
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(parameter);
@ -3290,6 +3304,12 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// ForwardCommunication BackwardCommunication TensorRedistribution
ParallelCommunication(root, all_nodes, manager);
auto group_info = g_device_manager->group_info();
if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save group info failed";
}
DumpGraph(root, std::string(STEP_PARALLEL_END));
// step parallel only run once

View File

@ -109,7 +109,7 @@ void CoverSliceShape(const FuncGraphPtr &root);
void SetVirtualDatasetStrategy(const CNodePtr &node);
// Creat parallel operator for primitive node(has strategy)
// Create parallel operator for primitive node(has strategy)
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training = true);
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);
@ -163,6 +163,8 @@ void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr
void SetLastNodeStrategy(const StrategyPtr strategyPtr);
bool CreateGroupsByCkptFile(const std::string &file);
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
} // namespace parallel
} // namespace mindspore

View File

@ -34,6 +34,8 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
}
return instance;
}
@ -46,6 +48,39 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
return false;
}
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
MS_EXCEPTION_IF_NULL(group_info_map);
if (!CheckPointExit(file)) {
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
}
straspb::ParallelGroupMap parallel_group_map;
std::fstream input(file, std::ios::in | std::ios::binary);
if (!parallel_group_map.ParseFromIstream(&input)) {
MS_LOG(ERROR) << "Load strategy file failed";
return FAILED;
}
input.close();
size_t group_num = LongToSize(parallel_group_map.parallel_group_item_size());
for (size_t i = 0; i < group_num; ++i) {
straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToLong(i));
std::string group_name = parallel_group_item.group_name();
straspb::ParallelGroupRanks parallel_group_ranks = parallel_group_item.parallel_group_ranks();
size_t rank_num = LongToSize(parallel_group_ranks.dim_size());
std::vector<uint32_t> ranks;
for (size_t j = 0; j < rank_num; ++j) {
uint32_t rank = parallel_group_ranks.dim(SizeToLong(j));
ranks.push_back(rank);
}
std::pair<std::string, std::vector<uint32_t>> group = std::make_pair(group_name, ranks);
group_info_map->push_back(group);
}
return SUCCESS;
}
Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
if (strategy_map == nullptr) {
MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
@ -141,5 +176,27 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
output.close();
return SUCCESS;
}
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
straspb::ParallelGroupMap parallel_group_map;
for (auto &group : group_info_map) {
straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
MS_EXCEPTION_IF_NULL(parallel_group_item);
parallel_group_item->set_group_name(group.first);
straspb::ParallelGroupRanks *parallel_group_ranks = parallel_group_item->mutable_parallel_group_ranks();
MS_EXCEPTION_IF_NULL(parallel_group_ranks);
for (auto &rank : group.second) {
parallel_group_ranks->add_dim(rank);
}
}
std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
if (!parallel_group_map.SerializeToOstream(&output)) {
MS_LOG(ERROR) << "Save strategy file failed";
return FAILED;
}
output.close();
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -32,6 +32,7 @@ namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class StrategyCheckpoint {
public:
StrategyCheckpoint() {
@ -40,11 +41,16 @@ class StrategyCheckpoint {
load_checkpoint_on_ = false;
save_file_ = "";
save_checkpoint_on_ = false;
group_info_save_file_ = "";
group_info_save_on_ = false;
}
~StrategyCheckpoint() = default;
Status Load(StrategyMap *strategy_map);
Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
Status SaveGroupInfo(const GroupInfoMap &group_info_map);
bool group_info_save_on() const { return group_info_save_on_; }
static StrategyCheckpoint &GetInstance();
bool LoadCheckPointOn() const { return load_checkpoint_on_; }
@ -57,6 +63,8 @@ class StrategyCheckpoint {
bool save_checkpoint_on_;
bool CheckPointExit(const std::string path) const;
int64_t current_stage_;
std::string group_info_save_file_;
bool group_info_save_on_;
};
} // namespace parallel
} // namespace mindspore

View File

@ -157,6 +157,7 @@ PYBIND11_MODULE(_c_expression, m) {
"Set strategy checkpoint save file.")
.def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_group_ckpt_save_file", &ParallelContext::set_group_ckpt_save_file, "Set group checkpoint save file.")
.def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
"Set pipeline stage split num.")
.def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")

View File

@ -61,6 +61,19 @@ message ParallelLayoutItem {
required ParallelLayouts parallel_layouts = 2;
}
message ParallelGroupRanks {
repeated uint32 dim = 1;
}
message ParallelGroupItem {
required string group_name = 1;
required ParallelGroupRanks parallel_group_ranks = 2;
}
message ParallelGroupMap {
repeated ParallelGroupItem parallel_group_item = 1;
}
message ParallelStrategyMap {
required uint32 current_stage = 1;
repeated ParallelStrategyItem parallel_strategy_item = 2;

View File

@ -283,6 +283,15 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_strategy_ckpt_save_file()
def set_group_ckpt_save_file(self, group_ckpt_save_file):
"""Set group checkpoint save path."""
self.check_context_handle()
import os
dir_path = os.path.dirname(group_ckpt_save_file)
if dir_path and not os.path.exists(dir_path):
os.makedirs(dir_path)
self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
def get_parameter_broadcast_is_set(self):
"""Get parameter broadcast is set or not."""
self.check_context_handle()
@ -505,6 +514,7 @@ _set_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
"group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
"full_batch": auto_parallel_context().set_full_batch,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
@ -533,7 +543,7 @@ _get_auto_parallel_context_func_map = {
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list)
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
def _set_auto_parallel_context(**kwargs):
"""
@ -574,6 +584,7 @@ def _set_auto_parallel_context(**kwargs):
broadcast. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.

View File

@ -31,5 +31,9 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
ManualShapeMap *manual_shape_map) { return SUCCESS; }
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { return SUCCESS; }
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { return SUCCESS; }
} // namespace parallel
} // namespace mindspore

View File

@ -75,7 +75,8 @@ def test_six_matmul_save():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
group_ckpt_save_file="./group_stage1.ckpt")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((1, 8), (8, 1))
strategy3 = ((2, 2), (2, 2))
@ -137,7 +138,8 @@ def test_six_matmul_load():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt",
group_ckpt_save_file="./group_stage1.ckpt")
strategy1 = ((8, 1), (1, 1))
strategy3 = ((8, 1), (1, 1))
strategy4 = ((8, 1), (1, 1))