forked from mindspore-Ecosystem/mindspore
!11472 support to checkpoint group info
From: @yangzhenzhang Reviewed-by: @stsuteng,@kisnwang Signed-off-by: @stsuteng
This commit is contained in:
commit
0641940b87
|
@ -124,6 +124,10 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck
|
|||
strategy_ckpt_save_file_ = strategy_ckpt_save_file;
|
||||
}
|
||||
|
||||
void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
|
||||
group_ckpt_save_file_ = group_ckpt_save_file;
|
||||
}
|
||||
|
||||
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
|
||||
all_reduce_fusion_split_indices_[group] = indices;
|
||||
}
|
||||
|
|
|
@ -102,6 +102,8 @@ class ParallelContext {
|
|||
std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
|
||||
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
|
||||
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
|
||||
void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
|
||||
std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }
|
||||
|
||||
void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
|
||||
enable_parallel_optimizer_ = enable_parallel_optimizer;
|
||||
|
@ -132,6 +134,7 @@ class ParallelContext {
|
|||
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
|
||||
std::string strategy_ckpt_load_file_;
|
||||
std::string strategy_ckpt_save_file_;
|
||||
std::string group_ckpt_save_file_;
|
||||
bool enable_parallel_optimizer_;
|
||||
};
|
||||
|
||||
|
|
|
@ -83,6 +83,7 @@ class DeviceManager {
|
|||
|
||||
void Clear();
|
||||
std::string world_group() const { return gm_.world_group(); }
|
||||
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); }
|
||||
std::string FindRankListNameByHashName(const std::string &hash_name);
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,16 +40,16 @@ class DynCreator {
|
|||
public:
|
||||
~DynCreator() = default;
|
||||
|
||||
// creat static singleton dyn_creator instance
|
||||
// create static singleton dyn_creator instance
|
||||
static DynCreator &Instance() {
|
||||
static DynCreator fac = DynCreator();
|
||||
return fac;
|
||||
}
|
||||
// register
|
||||
void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
|
||||
void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
|
||||
// creator
|
||||
OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
|
||||
const PrimitiveAttrs &attrs, size_t count) {
|
||||
OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
|
||||
const PrimitiveAttrs &attrs, size_t count) {
|
||||
std::string op_name = name + std::to_string(count);
|
||||
auto iter = Function_map_.find(name);
|
||||
if (iter == Function_map_.end()) {
|
||||
|
@ -67,7 +67,7 @@ class DynCreator {
|
|||
class RegisterAction {
|
||||
public:
|
||||
RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
|
||||
DynCreator::Instance().Regist(name, creatfn);
|
||||
DynCreator::Instance().Register(name, creatfn);
|
||||
}
|
||||
~RegisterAction() = default;
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "frontend/parallel/group_manager.h"
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "frontend/parallel/device_manager.h"
|
||||
#include "utils/comm_manager.h"
|
||||
|
@ -109,6 +110,9 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
|
|||
return Status::FAILED;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
|
||||
group_info_.push_back(group_info);
|
||||
|
||||
MS_LOG(INFO) << "Create group success, group name is " << group_name;
|
||||
return Status::SUCCESS;
|
||||
}
|
||||
|
@ -187,5 +191,27 @@ Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Gro
|
|||
}
|
||||
|
||||
void GroupManager::Clear() { (void)DestroyAllGroups(); }
|
||||
|
||||
Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
|
||||
// Create group through the executor
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
|
||||
MS_EXCEPTION_IF_NULL(executor);
|
||||
|
||||
for (auto &group : group_info) {
|
||||
bool ret = executor->CreateCommGroup(group.first, group.second);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "frontend/parallel/device.h"
|
||||
#include "frontend/parallel/status.h"
|
||||
|
@ -62,6 +63,7 @@ class GroupManager {
|
|||
Status FindGroup(const std::string &name, Group **group);
|
||||
std::string world_group() const { return world_group_; }
|
||||
void set_world_group(const std::string &name) { world_group_ = name; }
|
||||
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return group_info_; }
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
|
@ -69,7 +71,10 @@ class GroupManager {
|
|||
// the key is group name (name_)
|
||||
std::map<std::string, Group> groups_;
|
||||
std::string world_group_;
|
||||
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info_;
|
||||
};
|
||||
|
||||
Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -160,7 +160,7 @@ Status ReduceMethod::InferForwardCommunication() {
|
|||
Shape group_creat_map;
|
||||
|
||||
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
|
||||
// it need to handle the first dimention of map.
|
||||
// it need to handle the first dimension of map.
|
||||
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
|
||||
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
|
||||
}
|
||||
|
@ -200,12 +200,12 @@ Status ReduceMethod::InferForwardCommunication() {
|
|||
}
|
||||
|
||||
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
|
||||
// Creat AllReduceSum op
|
||||
// Create AllReduceSum op
|
||||
Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
|
||||
std::string group_name = forward_group[0].name();
|
||||
MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
|
||||
|
||||
// Creat RealDiv op
|
||||
// Create RealDiv op
|
||||
OperatorName operator1_name = REAL_DIV;
|
||||
std::vector<Device> device_list = forward_group[0].GetDevicesList();
|
||||
auto divisor = static_cast<float>(device_list.size());
|
||||
|
@ -237,7 +237,7 @@ Status ReduceMeanInfo::InferForwardCommunication() {
|
|||
Shape group_creat_map;
|
||||
|
||||
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
|
||||
// it need to handle the first dimention of map.
|
||||
// it need to handle the first dimension of map.
|
||||
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
|
||||
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
|
||||
}
|
||||
|
|
|
@ -326,7 +326,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
|
|||
std::string instance_name_base = FORWARD_OP;
|
||||
std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
|
||||
std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
|
||||
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to creat anfnode
|
||||
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to create anfnode
|
||||
MS_EXCEPTION_IF_NULL(forward_node);
|
||||
ScopePtr scope = node->scope();
|
||||
MS_EXCEPTION_IF_NULL(scope);
|
||||
|
@ -371,10 +371,10 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
|
|||
if (pos >= SizeToLong(node->inputs().size())) {
|
||||
MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size";
|
||||
}
|
||||
// Creat new node
|
||||
// Create new node
|
||||
AnfNodePtr target_node = node->input(LongToSize(pos));
|
||||
MS_EXCEPTION_IF_NULL(target_node);
|
||||
// Creat instance_name
|
||||
// Create instance_name
|
||||
auto op = (redistribution_oplist_ptr->first)[index];
|
||||
std::string op_name = (redistribution_oplist_ptr->first)[index].first;
|
||||
std::string instance_name_base = REDISTRIBUTION_OP;
|
||||
|
@ -400,7 +400,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func
|
|||
MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is "
|
||||
<< instance_name;
|
||||
}
|
||||
// Creat new node
|
||||
// Create new node
|
||||
AnfNodePtr pre_node = node->input(LongToSize(pos));
|
||||
MS_EXCEPTION_IF_NULL(pre_node);
|
||||
InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
|
||||
|
@ -595,7 +595,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|||
CNodePtr insert_node_new;
|
||||
|
||||
if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
|
||||
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
|
||||
MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
|
||||
return;
|
||||
}
|
||||
if (IsValueNode<Primitive>(node->input(0))) {
|
||||
|
@ -883,10 +883,10 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
|
|||
if (manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
|
||||
}
|
||||
// Sovle the input order
|
||||
// Solve the input order
|
||||
// For example input_node:{segment_sum:1, segment_sum:2, gahter:2}
|
||||
// The Original code here will bind the all operations to the first inputs of theses operatos
|
||||
// However, the segment_sum operation needs two inputs, To sovle this
|
||||
// The Original code here will bind the all operations to the first inputs of these operatos
|
||||
// However, the segment_sum operation needs two inputs, To solve this
|
||||
// We maintain a dict to count the times of the same operations,
|
||||
// and bind the inputs according to the times of the op appears.
|
||||
static std::unordered_map<AnfNodePtr, int> input_map = {};
|
||||
|
@ -1241,9 +1241,9 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
|
|||
}
|
||||
}
|
||||
OperatorInfoPtr operator_ =
|
||||
(OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
|
||||
(OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
|
||||
if (operator_ == nullptr) {
|
||||
MS_LOG(INFO) << "Creat " << name << " failed";
|
||||
MS_LOG(INFO) << "Create " << name << " failed";
|
||||
return nullptr;
|
||||
}
|
||||
std::string origin_name = operator_->name();
|
||||
|
@ -1261,7 +1261,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
|
|||
if (IsInBatchParallelBlackList(prim)) {
|
||||
MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
|
||||
}
|
||||
MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
|
||||
MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
|
||||
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
|
||||
MS_EXCEPTION_IF_NULL(operator_);
|
||||
}
|
||||
|
@ -1351,7 +1351,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
|
|||
}
|
||||
if (cnode->input(0)->isa<CNode>()) {
|
||||
if (cnode->inputs().size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2";
|
||||
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
|
||||
}
|
||||
base_shape_ptr = cnode->input(1)->Shape();
|
||||
}
|
||||
|
@ -2546,7 +2546,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
bool has_backward = !sens_loss_pairs.empty();
|
||||
// split sens must before inserting the operators.
|
||||
for (auto &pair : sens_loss_pairs) {
|
||||
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
|
||||
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
|
||||
// If the type of sens node is not Tensor, it is unsupported now, do nothing default.
|
||||
if (IsLastStage()) {
|
||||
StepSplitSens(pair);
|
||||
|
@ -2703,7 +2703,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
auto param_split_shapes = gatherv2_info->param_split_shapes();
|
||||
auto index_offsets = gatherv2_info->index_offsets();
|
||||
if (param_split_shapes.size() != index_offsets.size()) {
|
||||
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same.";
|
||||
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
|
||||
}
|
||||
std::vector<std::pair<int64_t, int64_t>> manual_shape;
|
||||
for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
|
||||
|
@ -2713,6 +2713,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
|
||||
}
|
||||
|
@ -3142,6 +3143,19 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
}
|
||||
|
||||
bool CreateGroupsByCkptFile(const std::string &file) {
|
||||
GroupInfoMap group_info_map;
|
||||
if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (CreateGroups(group_info_map) != SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Create groups by checkpoint file success";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
|
@ -3290,6 +3304,12 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
// ForwardCommunication BackwardCommunication TensorRedistribution
|
||||
ParallelCommunication(root, all_nodes, manager);
|
||||
|
||||
auto group_info = g_device_manager->group_info();
|
||||
if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
|
||||
StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Save group info failed";
|
||||
}
|
||||
|
||||
DumpGraph(root, std::string(STEP_PARALLEL_END));
|
||||
|
||||
// step parallel only run once
|
||||
|
|
|
@ -109,7 +109,7 @@ void CoverSliceShape(const FuncGraphPtr &root);
|
|||
|
||||
void SetVirtualDatasetStrategy(const CNodePtr &node);
|
||||
|
||||
// Creat parallel operator for primitive node(has strategy)
|
||||
// Create parallel operator for primitive node(has strategy)
|
||||
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training = true);
|
||||
|
||||
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);
|
||||
|
@ -163,6 +163,8 @@ void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr
|
|||
|
||||
void SetLastNodeStrategy(const StrategyPtr strategyPtr);
|
||||
|
||||
bool CreateGroupsByCkptFile(const std::string &file);
|
||||
|
||||
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,8 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
|
|||
instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
|
||||
instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
|
||||
instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
|
||||
instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
|
||||
instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
@ -46,6 +48,39 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
|
||||
MS_EXCEPTION_IF_NULL(group_info_map);
|
||||
if (!CheckPointExit(file)) {
|
||||
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
|
||||
}
|
||||
straspb::ParallelGroupMap parallel_group_map;
|
||||
std::fstream input(file, std::ios::in | std::ios::binary);
|
||||
if (!parallel_group_map.ParseFromIstream(&input)) {
|
||||
MS_LOG(ERROR) << "Load strategy file failed";
|
||||
return FAILED;
|
||||
}
|
||||
input.close();
|
||||
|
||||
size_t group_num = LongToSize(parallel_group_map.parallel_group_item_size());
|
||||
for (size_t i = 0; i < group_num; ++i) {
|
||||
straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToLong(i));
|
||||
std::string group_name = parallel_group_item.group_name();
|
||||
|
||||
straspb::ParallelGroupRanks parallel_group_ranks = parallel_group_item.parallel_group_ranks();
|
||||
size_t rank_num = LongToSize(parallel_group_ranks.dim_size());
|
||||
std::vector<uint32_t> ranks;
|
||||
for (size_t j = 0; j < rank_num; ++j) {
|
||||
uint32_t rank = parallel_group_ranks.dim(SizeToLong(j));
|
||||
ranks.push_back(rank);
|
||||
}
|
||||
|
||||
std::pair<std::string, std::vector<uint32_t>> group = std::make_pair(group_name, ranks);
|
||||
group_info_map->push_back(group);
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
|
||||
if (strategy_map == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
|
||||
|
@ -141,5 +176,27 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
output.close();
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
|
||||
straspb::ParallelGroupMap parallel_group_map;
|
||||
for (auto &group : group_info_map) {
|
||||
straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
|
||||
MS_EXCEPTION_IF_NULL(parallel_group_item);
|
||||
parallel_group_item->set_group_name(group.first);
|
||||
straspb::ParallelGroupRanks *parallel_group_ranks = parallel_group_item->mutable_parallel_group_ranks();
|
||||
MS_EXCEPTION_IF_NULL(parallel_group_ranks);
|
||||
for (auto &rank : group.second) {
|
||||
parallel_group_ranks->add_dim(rank);
|
||||
}
|
||||
}
|
||||
|
||||
std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
if (!parallel_group_map.SerializeToOstream(&output)) {
|
||||
MS_LOG(ERROR) << "Save strategy file failed";
|
||||
return FAILED;
|
||||
}
|
||||
output.close();
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,7 @@ namespace parallel {
|
|||
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
|
||||
using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
|
||||
using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
|
||||
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
|
||||
class StrategyCheckpoint {
|
||||
public:
|
||||
StrategyCheckpoint() {
|
||||
|
@ -40,11 +41,16 @@ class StrategyCheckpoint {
|
|||
load_checkpoint_on_ = false;
|
||||
save_file_ = "";
|
||||
save_checkpoint_on_ = false;
|
||||
group_info_save_file_ = "";
|
||||
group_info_save_on_ = false;
|
||||
}
|
||||
~StrategyCheckpoint() = default;
|
||||
|
||||
Status Load(StrategyMap *strategy_map);
|
||||
Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
|
||||
Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
|
||||
Status SaveGroupInfo(const GroupInfoMap &group_info_map);
|
||||
bool group_info_save_on() const { return group_info_save_on_; }
|
||||
|
||||
static StrategyCheckpoint &GetInstance();
|
||||
bool LoadCheckPointOn() const { return load_checkpoint_on_; }
|
||||
|
@ -57,6 +63,8 @@ class StrategyCheckpoint {
|
|||
bool save_checkpoint_on_;
|
||||
bool CheckPointExit(const std::string path) const;
|
||||
int64_t current_stage_;
|
||||
std::string group_info_save_file_;
|
||||
bool group_info_save_on_;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -157,6 +157,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Set strategy checkpoint save file.")
|
||||
.def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
|
||||
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
|
||||
.def("set_group_ckpt_save_file", &ParallelContext::set_group_ckpt_save_file, "Set group checkpoint save file.")
|
||||
.def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
|
||||
"Set pipeline stage split num.")
|
||||
.def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")
|
||||
|
|
|
@ -61,6 +61,19 @@ message ParallelLayoutItem {
|
|||
required ParallelLayouts parallel_layouts = 2;
|
||||
}
|
||||
|
||||
message ParallelGroupRanks {
|
||||
repeated uint32 dim = 1;
|
||||
}
|
||||
|
||||
message ParallelGroupItem {
|
||||
required string group_name = 1;
|
||||
required ParallelGroupRanks parallel_group_ranks = 2;
|
||||
}
|
||||
|
||||
message ParallelGroupMap {
|
||||
repeated ParallelGroupItem parallel_group_item = 1;
|
||||
}
|
||||
|
||||
message ParallelStrategyMap {
|
||||
required uint32 current_stage = 1;
|
||||
repeated ParallelStrategyItem parallel_strategy_item = 2;
|
||||
|
|
|
@ -283,6 +283,15 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
return self._context_handle.get_strategy_ckpt_save_file()
|
||||
|
||||
def set_group_ckpt_save_file(self, group_ckpt_save_file):
|
||||
"""Set group checkpoint save path."""
|
||||
self.check_context_handle()
|
||||
import os
|
||||
dir_path = os.path.dirname(group_ckpt_save_file)
|
||||
if dir_path and not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
|
||||
|
||||
def get_parameter_broadcast_is_set(self):
|
||||
"""Get parameter broadcast is set or not."""
|
||||
self.check_context_handle()
|
||||
|
@ -505,6 +514,7 @@ _set_auto_parallel_context_func_map = {
|
|||
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
|
||||
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
|
||||
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
|
||||
"group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
|
||||
"full_batch": auto_parallel_context().set_full_batch,
|
||||
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
|
||||
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
|
||||
|
@ -533,7 +543,7 @@ _get_auto_parallel_context_func_map = {
|
|||
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
|
||||
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
||||
grad_accumulation_step=int, all_reduce_fusion_config=list)
|
||||
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
|
||||
|
||||
def _set_auto_parallel_context(**kwargs):
|
||||
"""
|
||||
|
@ -574,6 +584,7 @@ def _set_auto_parallel_context(**kwargs):
|
|||
broadcast. Default: False.
|
||||
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
|
||||
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
|
||||
group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
|
||||
full_batch (bool): Whether to load the whole batch on each device. Default: False.
|
||||
enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
|
||||
all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
|
||||
|
|
|
@ -31,5 +31,9 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
|
|||
|
||||
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
|
||||
ManualShapeMap *manual_shape_map) { return SUCCESS; }
|
||||
|
||||
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { return SUCCESS; }
|
||||
|
||||
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { return SUCCESS; }
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -75,7 +75,8 @@ def test_six_matmul_save():
|
|||
return out
|
||||
|
||||
reset_auto_parallel_context()
|
||||
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
|
||||
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
|
||||
group_ckpt_save_file="./group_stage1.ckpt")
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((1, 8), (8, 1))
|
||||
strategy3 = ((2, 2), (2, 2))
|
||||
|
@ -137,7 +138,8 @@ def test_six_matmul_load():
|
|||
return out
|
||||
|
||||
reset_auto_parallel_context()
|
||||
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
|
||||
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt",
|
||||
group_ckpt_save_file="./group_stage1.ckpt")
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy3 = ((8, 1), (1, 1))
|
||||
strategy4 = ((8, 1), (1, 1))
|
||||
|
|
Loading…
Reference in New Issue