forked from mindspore-Ecosystem/mindspore

auto parallel strategy checkpoint

parent 420ef2a352, commit 6cde5f6d91

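In user-facing terms, this commit replaces the old environment-variable switches (PARALLEL_CHECKPOINT_ON, PARALLEL_TRAIN_TIMES, removed below) with explicit load/save paths on the auto parallel context. A minimal sketch of the multi-stage workflow this enables, using only the Python API added in this diff (file names are illustrative):

# Stage 1: derive parallel strategies during compilation and persist them.
from mindspore import context
context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel",
                                  strategy_ckpt_save_file="./strategy_stage1.ckpt")
# ... build and compile the stage-1 network ...

# Stage 2 (a later training run): reuse the saved strategies instead of re-deriving them.
context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
# ... build and compile the stage-2 network ...
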
@@ -52,7 +52,11 @@ class Primitive : public Named {
       : Named(name), signatures_(), prim_type_(prim_type) {}
 
   Primitive(const Primitive &prim)
-      : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
+      : Named(prim),
+        attrs_(prim.attrs_),
+        signatures_(prim.signatures_),
+        instance_name_(prim.instance_name_),
+        prim_type_(prim.prim_type_) {}
 
   MS_DECLARE_PARENT(Primitive, Named);
 
@@ -56,6 +56,8 @@ void ParallelContext::Reset() {
   parameter_broadcast_ = false;
   parameter_broadcast_is_set_ = false;
   enable_all_reduce_fusion_ = false;
+  strategy_ckpt_load_file_ = "";
+  strategy_ckpt_save_file_ = "";
 }
 
 void ParallelContext::set_device_num(int32_t device_num) {
@@ -103,6 +105,14 @@ void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
   parameter_broadcast_is_set_ = true;
 }
 
+void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
+  strategy_ckpt_load_file_ = strategy_ckpt_load_file;
+}
+
+void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
+  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
+}
+
 void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) {
   all_reduce_fusion_split_indices_ = indices;
 }
 
@@ -85,6 +85,11 @@ class ParallelContext {
   }
   bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
 
+  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
+  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
+  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
+  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
+
   void Reset();
 
  private:
@@ -105,6 +110,8 @@ class ParallelContext {
   bool enable_all_reduce_fusion_;
   std::vector<uint32_t> all_reduce_fusion_split_indices_;
   std::vector<uint32_t> all_reduce_fusion_split_sizes_;
+  std::string strategy_ckpt_load_file_;
+  std::string strategy_ckpt_save_file_;
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -40,6 +40,7 @@
 #include "parallel/context.h"
 #include "parallel/ops_info/tmp_identity_info.h"
 #include "parallel/step_parallel.h"
+#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "pipeline/parse/python_adapter.h"
 #include "pipeline/pipeline.h"
 
@@ -339,7 +340,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
   return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
 }
 
-OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
+OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
   MS_EXCEPTION_IF_NULL(prim);
   MS_EXCEPTION_IF_NULL(cnode);
   auto attrs = prim->attrs();
@@ -385,9 +386,15 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_input_value(input_value);
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
+  // key of strategy map
+  std::string instance_name = prim->instance_name();
+  std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+  bool load_strategy_from_ckpt =
+    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
-  if (!StrategyFound(attrs) || prim->name() == CAST) {
+  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
+  // If the strategy is configured to load from checkpoint, the checkpointed strategy is preferred.
+  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
     operator_info->ComputeBatchSplitFlagList();
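The checkpoint key combines the node's scope name and the primitive's instance name. A small illustrative sketch of the scheme (the actual CONNSYMBOL separator is defined elsewhere in the sources and is not shown in this diff; "-" here is an assumption):

def strategy_key(scope_name, instance_name, conn_symbol="-"):
    # Mirrors cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name above.
    return scope_name + conn_symbol + instance_name

assert strategy_key("Default/network", "matmul1") == "Default/network-matmul1"
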
@@ -397,7 +404,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     }
   } else {
     // In this case, the configured strategy should be extracted to help setting cost
-    StrategyPtr strategyPtr = parallel::ExtractStrategy(attrs);
+    StrategyPtr strategyPtr;
+    if (load_strategy_from_ckpt) {
+      strategyPtr = (*stra_map)[strategy_key_name];
+    } else {
+      strategyPtr = parallel::ExtractStrategy(attrs);
+    }
     if (strategyPtr != nullptr) {
       if (prim->name() == RESHAPE) {
         MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
@@ -433,7 +445,13 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
   entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueId to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   // Step 1
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
@@ -451,7 +469,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
 
   auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
   if (search_cnode == from_cnode_to_info.end()) {
-    auto operator_info = CreateTheOperatorInfo(prim, cnode);
+    auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
     if (operator_info == nullptr) {
       return FAILED;
     }
@@ -486,7 +504,13 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueIdThroughCopy to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast<CNodePtr>();
@@ -504,7 +528,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
   if (search_cnode == from_cnode_to_info.end()) {
     // In this case, the corresponding OperatorInfo is not created, create the new one.
-    auto operator_info = CreateTheOperatorInfo(prim, cnode);
+    auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
     if (operator_info == nullptr) {
       return FAILED;
     }
@@ -1378,6 +1378,13 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
 }
 
 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
+  // load strategy map from checkpoint
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
@@ -1414,7 +1421,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
       (void)cnode->set_operator_info(operator_);
       continue;
     }
-    if (!StrategyFound(attrs)) {
+    // load strategy checkpoint
+    // key of strategy map
+    std::string instance_name = prim->instance_name();
+    std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+    bool load_strategy_from_ckpt =
+      StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
+
+    if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
       MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
                    << " is empty, using batch parallel";
       std::shared_ptr<std::vector<Dimensions>> strategy_v_ptr = operator_->GenerateBatchStrategies();
@@ -1432,6 +1446,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
       MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
                    << attrs[GEN_STRATEGY]->ToString();
       strategyPtr = NewStrategy(0, *strategy_v_ptr);
+    } else if (load_strategy_from_ckpt) {
+      strategyPtr = stra_map[strategy_key_name];
     } else {
       strategyPtr = ExtractStrategy(attrs);
     }
@@ -2022,16 +2038,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
   }
 }
 
+bool NodeWithParameter(const CNodePtr &node) {
+  std::vector<AnfNodePtr> node_inputs{node->inputs()};
+  for (auto input : node_inputs) {
+    if (input->isa<Parameter>()) {
+      auto input_parameter = input->cast<ParameterPtr>();
+      if (input_parameter->has_default()) {
+        return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
+      }
+    }
+  }
+  return false;
+}
+
 void CheckpointStrategy(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  MS_LOG(INFO) << "Save strategy to checkpoint begin";
-  StrategyMap straMap;
+  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
+  StrategyMap stra_map;
   auto ret = func_graph->get_return();
   auto all_nodes = DeepScopedGraphSearch(ret);
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) {
       continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@@ -2039,57 +2068,20 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
     OperatorInfoPtr operator_info = cnode->operator_info();
     if (operator_info) {
       if (prim->instance_name().empty()) {
-        continue;
+        MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
       }
       std::string instance_name = prim->instance_name();
       StrategyPtr strategyPtr = operator_info->strategy();
       MS_EXCEPTION_IF_NULL(node->scope());
       std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      straMap[node_name] = strategyPtr;
+      stra_map[node_name] = strategyPtr;
     }
   }
-  if (StrategyCheckpoint::GetInstance().Save(straMap) != SUCCESS) {
+  if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
     MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
   }
 }
 
-void RestoreStrategy(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_LOG(INFO) << "Extract strategy from checkpoint begin";
-  StrategyMap straMap;
-  if (StrategyCheckpoint::GetInstance().Load(&straMap) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-  }
-  if (StrategyCheckpoint::GetInstance().RemoveCheckPoint() != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Remove strategy checkpoint failed";
-  }
-  auto ret = func_graph->get_return();
-  auto all_nodes = DeepScopedGraphSearch(ret);
-  for (auto &node : all_nodes) {
-    MS_EXCEPTION_IF_NULL(node);
-    auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
-      continue;
-    }
-    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-    MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->operator_info();
-    if (operator_info) {
-      if (prim->instance_name().empty()) {
-        continue;
-      }
-      std::string instance_name = prim->instance_name();
-      MS_EXCEPTION_IF_NULL(node->scope());
-      std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      MS_LOG(INFO) << "Node name is " << node_name;
-      if (straMap.find(node_name) != straMap.end()) {
-        StrategyPtr strategyPtr = straMap[node_name];
-        operator_info->set_strategy(strategyPtr);
-      }
-    }
-  }
-}
-
 void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
@@ -2264,14 +2256,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
     // extract shape and strategy, set operator_info
     ExtractInformation(all_nodes);
     ReshapeInit(all_nodes);
-    // extract strategy from checkpoint for multi-train
-    if (StrategyCheckpoint::GetInstance().CheckPointOn() && StrategyCheckpoint::GetInstance().CheckPointExit()) {
-      RestoreStrategy(root);
-    }
   }
   // save strategy as checkpoint for multi-train
-  if (StrategyCheckpoint::GetInstance().CheckPointOn() &&
-      StrategyCheckpoint::GetInstance().GetCurrentTrainTime() < StrategyCheckpoint::GetInstance().GetTrainTimes()) {
+  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
     CheckpointStrategy(root);
   }
 
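After this hunk the control flow is one-directional: loading happens per-operator while strategies are extracted (ExtractInformation and CreateTheOperatorInfo consult the loaded strategy map), and saving happens once at the end of StepParallel. A hedged pseudocode summary, not actual MindSpore code:

def step_parallel(root, all_nodes):
    extract_information(all_nodes)   # consults the loaded strategy map when LoadCheckPointOn()
    reshape_init(all_nodes)
    if strategy_checkpoint.save_checkpoint_on():
        checkpoint_strategy(root)    # persists strategies of nodes that carry trainable parameters
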
@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager);
 
-void RestoreStrategy(const FuncGraphPtr &func_graph);
+bool NodeWithParameter(const CNodePtr &node);
 
 void CheckpointStrategy(const FuncGraphPtr &func_graph);
 
@@ -29,30 +29,32 @@ namespace mindspore {
 namespace parallel {
 StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
   static StrategyCheckpoint instance = StrategyCheckpoint();
+  if (ParallelContext::GetInstance() != nullptr) {
+    instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
+    instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
+    instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
+    instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
+  }
   return instance;
 }
 
-bool StrategyCheckpoint::CheckPointExit() const {
-  std::ifstream fin(path_);
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
+  std::ifstream fin(path);
   if (fin) {
     return true;
   }
   return false;
 }
 
 Status StrategyCheckpoint::RemoveCheckPoint() const {
   if (std::remove(common::SafeCStr(path_)) == 0) {
     return SUCCESS;
   }
   return FAILED;
 }
 
 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
   if (strategy_map == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
   }
+  if (!CheckPointExit(load_file_)) {
+    MS_LOG(EXCEPTION) << "CheckPoint file is not found";
+  }
   straspb::ParallelStrategyMap parallel_strategy_map;
-  std::fstream input(path_, std::ios::in | std::ios::binary);
+  std::fstream input(load_file_, std::ios::in | std::ios::binary);
   if (!parallel_strategy_map.ParseFromIstream(&input)) {
     MS_LOG(ERROR) << "Load strategy file failed";
     return FAILED;
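Note the design choice here: GetInstance() re-reads the load/save paths from ParallelContext on every call, so a path set from Python takes effect the next time any component asks for the singleton, with no explicit refresh step. A rough Python analogy of that refresh-on-access pattern (not MindSpore code):

class _StrategyCheckpoint:
    _instance = None

    @classmethod
    def get_instance(cls, parallel_context):
        if cls._instance is None:
            cls._instance = cls()
        inst = cls._instance
        # Re-sync configuration from the context on every access.
        inst.load_file = parallel_context.strategy_ckpt_load_file
        inst.load_checkpoint_on = bool(inst.load_file)
        inst.save_file = parallel_context.strategy_ckpt_save_file
        inst.save_checkpoint_on = bool(inst.save_file)
        return inst
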
@@ -77,14 +79,14 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
 
     StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
     (*strategy_map)[node_name] = strategy;
-    current_train_time_ = (int32_t)parallel_strategy_map.train_time();
+    current_stage_ = (int32_t)parallel_strategy_map.current_stage();
   }
   return SUCCESS;
 }
 
 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
   straspb::ParallelStrategyMap parallel_strategy_map;
-  parallel_strategy_map.set_train_time(IntToUint(++current_train_time_));
+  parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
   for (auto &node_stra : strategy_map) {
     straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
     MS_EXCEPTION_IF_NULL(parallel_strategy_item);
@@ -100,7 +102,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
       }
     }
   }
-  std::fstream output(path_, std::ios::out | std::ios::trunc | std::ios::binary);
+  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
   if (!parallel_strategy_map.SerializeToOstream(&output)) {
     MS_LOG(ERROR) << "Save strategy file failed";
     return FAILED;
@@ -21,43 +21,37 @@
 #include <unordered_map>
 #include "parallel/ops_info/ops_utils.h"
 #include "parallel/strategy.h"
+#include "parallel/context.h"
 
 namespace mindspore {
 namespace parallel {
-constexpr char DEFAULT_CHECKPOINT_PATH[] = "./strategys.ckpt";
-
 using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
 class StrategyCheckpoint {
  public:
-  StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) {
-    train_times_ = 1;
-    checkpoint_on_ = false;
-    const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES");
-    if (train_times_str != nullptr && std::stoi(train_times_str) > 0) {
-      train_times_ = std::stoi(train_times_str);
-    }
-    const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON");
-    if (checkpoint_on_str != nullptr) {
-      checkpoint_on_ = (std::string(checkpoint_on_str) == "on");
-    }
+  StrategyCheckpoint() {
+    current_stage_ = 0;
+    load_file_ = "";
+    load_checkpoint_on_ = false;
+    save_file_ = "";
+    save_checkpoint_on_ = false;
   }
   ~StrategyCheckpoint() = default;
-  bool CheckPointExit() const;
   Status RemoveCheckPoint() const;
 
   Status Load(StrategyMap *strategy_map);
   Status Save(const StrategyMap &strategy_map);
 
   static StrategyCheckpoint &GetInstance();
-  int32_t GetTrainTimes() const { return train_times_; }
-  int32_t GetCurrentTrainTime() const { return current_train_time_; }
-  bool CheckPointOn() const { return checkpoint_on_; }
+  bool LoadCheckPointOn() const { return load_checkpoint_on_; }
+  bool SaveCheckPointOn() const { return save_checkpoint_on_; }
 
  private:
-  std::string path_;
-  bool checkpoint_on_;
-  // total train times for a train, get from Environmental variable:TRAIN_TIME, please export it
-  int32_t train_times_;
-  int32_t current_train_time_;
+  std::string load_file_;
+  std::string save_file_;
+  bool load_checkpoint_on_;
+  bool save_checkpoint_on_;
+  bool CheckPointExit(const std::string path) const;
+  int32_t current_stage_;
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -191,6 +191,12 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
          "Get parameter broadcast is set.")
     .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
+    .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
+         "Set strategy checkpoint load file.")
+    .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
+         "Set strategy checkpoint save file.")
+    .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
+    .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
@@ -33,6 +33,6 @@ message ParallelStrategyItem {
 }
 
 message ParallelStrategyMap {
-  required uint32 train_time = 1;
+  required uint32 current_stage = 1;
   repeated ParallelStrategyItem parallel_strategy_item = 2;
 }
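Because protobuf encodes field numbers rather than field names on the wire, renaming train_time to current_stage while keeping tag 1 leaves the serialized format unchanged; only the generated accessor names change. A hedged round-trip sketch, assuming the module generated from this .proto is named node_strategy_pb2 (typical protoc output, not shown in the diff):

from node_strategy_pb2 import ParallelStrategyMap  # assumed generated-module name

msg = ParallelStrategyMap()
msg.current_stage = 1
data = msg.SerializeToString()

parsed = ParallelStrategyMap()
parsed.ParseFromString(data)
assert parsed.current_stage == 1
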
@@ -404,7 +404,7 @@ def _context():
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
-                 parameter_broadcast=bool)
+                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -436,6 +436,8 @@ def set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -447,6 +449,8 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(cast_before_mirror=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
+        >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
+        >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
     """
     _set_auto_parallel_context(**kwargs)
 
@@ -477,6 +481,8 @@ def reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: "".
+    - strategy_ckpt_save_file: "".
     """
     _reset_auto_parallel_context()
 
@@ -88,6 +88,8 @@ class Primitive(Primitive_):
         for name in self.attrs:
             value = self.attrs[name]
             cloned.add_prim_attr(name, value)
+        if hasattr(self, 'instance_name'):
+            cloned.set_prim_instance_name(self.instance_name)
         return cloned
 
     def add_prim_attr(self, name, value):
@@ -208,6 +208,36 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_parameter_broadcast()
 
+    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
+        """
+        Set strategy checkpoint load path.
+
+        Args:
+            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
+
+    def get_strategy_ckpt_load_file(self):
+        """Get strategy checkpoint load path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_load_file()
+
+    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
+        """
+        Set strategy checkpoint save path.
+
+        Args:
+            strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
+
+    def get_strategy_ckpt_save_file(self):
+        """Get strategy checkpoint save path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_save_file()
+
     def get_parameter_broadcast_is_set(self):
         """Get parameter broadcast is set or not."""
         self.check_context_handle()
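A short usage sketch of the wrapper methods added above, assuming the usual module location of auto_parallel_context (mindspore.parallel._auto_parallel_context, not spelled out in this diff):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

auto_parallel_context().set_strategy_ckpt_save_file("./strategy_stage1.ckpt")
assert auto_parallel_context().get_strategy_ckpt_save_file() == "./strategy_stage1.ckpt"
auto_parallel_context().reset()  # resets the whole context, including both paths, to defaults
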
@@ -315,7 +345,9 @@ _set_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}
 
 
 _get_auto_parallel_context_func_map = {
@@ -325,11 +357,14 @@ _get_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
-                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool)
+                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
+                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -360,6 +395,8 @@ def _set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -400,5 +437,7 @@ def _reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: ""
+    - strategy_ckpt_save_file: ""
     """
     auto_parallel_context().reset()
@@ -25,9 +25,7 @@ StrategyCheckpoint& StrategyCheckpoint::GetInstance() {
   return instance;
 }
 
-bool StrategyCheckpoint::CheckPointExit() const { return false; }
-
-Status StrategyCheckpoint::RemoveCheckPoint() const { return SUCCESS; }
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return false; }
 
 Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
 
@@ -14,10 +14,10 @@
 
 import numpy as np
 from mindspore import context
-from mindspore.context import set_auto_parallel_context
+from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from tests.ut.python.ops.test_math_ops import VirtualLoss
 import mindspore as ms
 from mindspore.common.api import _executor
@@ -25,17 +25,15 @@ from mindspore.ops import composite as C
 
 
 # model_parallel test
-# export PARALLEL_CHECKPOINT_ON=on
-# export PARALLEL_TRAIN_TIMES=4
-def test_six_matmul():
+def test_six_matmul_save():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            predict = self.network(x1, x6)
             return self.loss(predict)
 
 
@@ -44,8 +42,8 @@ def test_six_matmul():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)
 
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
@@ -56,45 +54,46 @@ def test_six_matmul():
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((1, 8), (8, 1))
     strategy3 = ((2, 2), (2, 2))
-    strategy4 = ((4, 2), (2, 4))
-    strategy5 = ((2, 4), (4, 2))
-    strategy6 = ((4, 4), (4, 4))
+    strategy4 = ((1, 1), (1, 8))
+    strategy5 = ((4, 2), (2, 1))
+    strategy6 = ((4, 1), (1, 2))
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
 
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x3 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x3, x4, x5, x6, x7)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)
 
 
-# remove matmul2
-def test_six_matmul_repeated1():
+# remove matmul2, add matmul7
+def test_six_matmul_load():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x4, x5, x6, x7)
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
             return self.loss(predict)
 
@@ -103,63 +102,8 @@ def test_six_matmul_repeated1():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7)
-
-    class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6):
-            super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul3 = P.MatMul().set_strategy(strategy3)
-            self.matmul4 = P.MatMul().set_strategy(strategy4)
-            self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            return out
-
-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6)))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7)
-
-
-# add matmul7
-def test_six_matmul_repeated2():
-    class NetWithLoss(nn.Cell):
-        def __init__(self, network):
-            super(NetWithLoss, self).__init__()
-            self.loss = VirtualLoss()
-            self.network = network
-
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            predict = self.network(x1, x2, x4, x5, x6, x7, x8)
-            return self.loss(predict)
-
-
-    class GradWrap(nn.Cell):
-        def __init__(self, network):
-            super(GradWrap, self).__init__()
-            self.network = network
-
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8)
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)
 
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
@@ -170,17 +114,22 @@
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
             self.matmul7 = P.MatMul().set_strategy(strategy7)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy3 = ((8, 1), (1, 1))
     strategy4 = ((8, 1), (1, 1))
|
|||
net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
|
||||
x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||
x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
|
||||
x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
|
||||
x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
|
||||
_executor.compile(net, x1, x2, x4, x5, x6, x7, x8)
|
||||
_executor.compile(net, x1, x6, x7)
|
||||
|
||||
|
||||
# add scope2
|
||||
def test_six_matmul_repeated3():
|
||||
# model_parallel test
|
||||
def test_six_matmul_save_auto():
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network1, network2):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network1
|
||||
self.network2 = network2
|
||||
self.network = network
|
||||
|
||||
def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
|
||||
predict = self.network(x1, x2, x4, x5, x6, x7, x8)
|
||||
predict = self.network2(predict, x9, x10)
|
||||
def construct(self, x1, x6):
|
||||
predict = self.network(x1, x6)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
|
@@ -220,62 +162,96 @@ def test_six_matmul_repeated3():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8, x9, x10)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)
 
     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
+        def __init__(self):
             super().__init__()
+            self.matmul1 = P.MatMul()
+            self.matmul2 = P.MatMul()
+            self.matmul3 = P.MatMul()
+            self.matmul4 = P.MatMul()
+            self.matmul5 = P.MatMul()
+            self.matmul6 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
+
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            return out
+
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)
+
+
+# remove matmul2, add matmul7
+def test_six_matmul_load_auto():
+    class NetWithLoss(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss, self).__init__()
+            self.loss = VirtualLoss()
+            self.network = network
+
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
+            return self.loss(predict)
+
+
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)
+
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy3, strategy4, strategy5):
+            super().__init__()
             self.matmul1 = P.MatMul().set_strategy(strategy1)
             self.matmul3 = P.MatMul().set_strategy(strategy3)
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-            self.matmul7 = P.MatMul().set_strategy(strategy7)
+            self.matmul6 = P.MatMul()
+            self.matmul7 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out
 
-    class Net1(nn.Cell):
-        def __init__(self, strategy1, strategy2):
-            super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul2 = P.MatMul().set_strategy(strategy2)
-
-        def construct(self, x1, x2, x3):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
-            return out
-
-
-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    strategy7 = ((8, 1), (1, 1))
-    strategy8 = ((8, 1), (1, 1))
-    strategy9 = ((8, 1), (1, 1))
-    net1 = Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)
-    net2 = Net1(strategy8, strategy9)
-    net = GradWrap(NetWithLoss(net1, net2))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
+    strategy1 = ((2, 2), (2, 2))
+    strategy3 = ((2, 2), (2, 2))
+    strategy4 = ((2, 2), (2, 2))
+    strategy5 = ((2, 2), (2, 2))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
 
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
     x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    x9 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x10 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7, x8, x9, x10)
-
+    _executor.compile(net, x1, x6, x7)
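The *_load tests consume the checkpoint files written by the corresponding *_save tests, so they rely on the save tests running first (pytest's default in-file ordering). A quick sanity check, illustrative only:

import os
for ckpt in ("./strategy_stage1.ckpt", "./strategy_stage1_auto.ckpt"):
    print(ckpt, "exists:", os.path.exists(ckpt))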