forked from mindspore-Ecosystem/mindspore
!47970 strategy_ckpt_viewable
Merge pull request !47970 from yao_yf/strategy_ckpt_viewable
commit 73a445c03a

@ -55,12 +55,15 @@ if(ENABLE_D OR ENABLE_GPU)
 endif()
 if(NOT BUILD_LITE)
+    list(APPEND MSLIB_SRC_DEPEND
+        ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc)
     list(APPEND MSLIB_SRC_DEPEND
         ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc)
     list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc)
     list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc)
     list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc)
     list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc)
+    list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc)
     list(APPEND MSLIB_SRC_DEPEND ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc)
 endif()

@ -7,6 +7,7 @@ endif()
 if(ENABLE_DUMP_PROTO)
     list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
+    list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/strategy_checkpoint_info.cc")
 endif()

 if(CMAKE_SYSTEM_NAME MATCHES "Darwin")

@ -2239,7 +2239,7 @@ static void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const F
     }
     tensor_info_map[cloned_param_name] = cloned_param_layout;
   }
-  if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
+  if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, manual_shape_map) != SUCCESS) {
     MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
   }
 }

@ -1268,7 +1268,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
     auto input = node_inputs[LongToSize(i)];
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
-      if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
+      if (input_parameter->has_default()) {
         (void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
       }
     } else if (input->isa<CNode>()) {

@ -19,17 +19,16 @@
 #include <fstream>
 #include <vector>
 #include <utility>

 #include "include/common/utils/utils.h"
 #include "utils/ms_utils.h"
 #include "include/common/utils/convert_utils.h"
 #include "utils/log_adapter.h"
 #include "include/common/debug/common.h"
 #include "proto/node_strategy.pb.h"
 #include "mindspore/core/utils/file_utils.h"

 namespace mindspore {
 namespace parallel {
+const uint32_t JSON_SUFFIX_LENGTH = 5;
 StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
   static StrategyCheckpoint instance = StrategyCheckpoint();
   if (ParallelContext::GetInstance() != nullptr) {

@ -39,6 +38,10 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
     instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
     instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
     instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
+    instance.load_format_json_ = instance.load_file_.size() >= JSON_SUFFIX_LENGTH &&
+                                 instance.load_file_.substr(instance.load_file_.size() - JSON_SUFFIX_LENGTH) == ".json";
+    instance.save_format_json_ = instance.save_file_.size() >= JSON_SUFFIX_LENGTH &&
+                                 instance.save_file_.substr(instance.save_file_.size() - JSON_SUFFIX_LENGTH) == ".json";
   }
   return instance;
 }

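The two flags above are derived purely from the configured file suffix, so the checkpoint format is chosen per file name. A minimal sketch of driving this from Python, mirroring the test updates at the end of this diff (file names are illustrative):

    import mindspore as ms

    # a ".json" suffix selects the new human-readable JSON strategy checkpoint
    ms.set_auto_parallel_context(strategy_ckpt_save_file="./strategy.json")
    # any other suffix (e.g. ".ckpt") keeps the binary protobuf format
    ms.set_auto_parallel_context(strategy_ckpt_load_file="./strategy.ckpt")
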
@ -110,96 +113,47 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
   if (!CheckPointExit(load_file_)) {
     MS_LOG(EXCEPTION) << "CheckPoint file is not found";
   }
-  straspb::ParallelStrategyMap parallel_strategy_map;
-  std::fstream input(load_file_, std::ios::in | std::ios::binary);
-  if (!parallel_strategy_map.ParseFromIstream(&input)) {
-    MS_LOG(ERROR) << "Load strategy file failed";
-    return FAILED;
-  }
-  input.close();
-  size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
-  for (size_t i = 0; i < node_num; i++) {
-    straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
-    std::string node_name = parallel_strategy_item.node_name();
-    straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
-    int64_t stage = SizeToLong(parallel_strategys.stage());
-    size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
-    Strategies strategy_inputs;
-    for (size_t j = 0; j < strategys_num; j++) {
-      straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
-      Dimensions dimension;
-      size_t dim_num = LongToSize(parallel_strategy.dim_size());
-      for (size_t k = 0; k < dim_num; k++) {
-        dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
-      }
-      strategy_inputs.push_back(dimension);
-    }
-    StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
-    (*strategy_map)[node_name] = strategy;
-    current_stage_ = SizeToLong(parallel_strategy_map.current_stage());
-  }
+  if (load_format_json_) {
+    std::fstream input(load_file_, std::ios::in);
+    nlohmann::json stra_ckpt_info_j;
+    input >> stra_ckpt_info_j;
+    strategy_checkpoint_info_.from_json(stra_ckpt_info_j);
+  } else {
+    straspb::ParallelStrategyMap parallel_strategy_map;
+    std::fstream input(load_file_, std::ios::in | std::ios::binary);
+    if (!parallel_strategy_map.ParseFromIstream(&input)) {
+      MS_LOG(ERROR) << "Load strategy file failed";
+      return FAILED;
+    }
+    input.close();
+    strategy_checkpoint_info_.from_protobuf(parallel_strategy_map);
+  }
+  *strategy_map = strategy_checkpoint_info_.strategy_map();
+  current_stage_ = SizeToLong(strategy_checkpoint_info_.current_stage());
   return SUCCESS;
 }

 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
-                                ManualShapeMap *manual_shape_map) {
-  straspb::ParallelStrategyMap parallel_strategy_map;
-  parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(++current_stage_)));
-  for (auto &node_stra : strategy_map) {
-    straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
-    MS_EXCEPTION_IF_NULL(parallel_strategy_item);
-    parallel_strategy_item->set_node_name(node_stra.first);
-    straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
-    MS_EXCEPTION_IF_NULL(parallel_strategys);
-    MS_EXCEPTION_IF_NULL(node_stra.second);
-    parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
-    for (auto &dims : node_stra.second->GetInputDim()) {
-      straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
-      MS_EXCEPTION_IF_NULL(parallel_strategy);
-      for (auto stra_dim : dims) {
-        parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
-      }
-    }
-  }
-  for (auto &node_tensor_info : tensor_info_map) {
-    TensorLayoutPtr tensor_layout = node_tensor_info.second;
-    MS_EXCEPTION_IF_NULL(tensor_layout);
-    straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
-    MS_EXCEPTION_IF_NULL(parallel_layout_item);
-    parallel_layout_item->set_param_name(node_tensor_info.first);
-    straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
-    straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
-    MS_EXCEPTION_IF_NULL(dev_matrix);
-    for (auto dev_dim : tensor_layout->device_arrangement().array()) {
-      dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
-    }
-    straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
-    MS_EXCEPTION_IF_NULL(tensor_map);
-    for (auto map_dim : tensor_layout->tensor_map().array()) {
-      tensor_map->add_dim(LongToInt(map_dim));
-    }
-    straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
-    straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
-    MS_EXCEPTION_IF_NULL(manual_shape_map);
-    auto manual_shape = (*manual_shape_map)[node_tensor_info.first];
-    for (auto dim_pair : manual_shape) {
-      param_split_shape->add_dim(dim_pair.first);
-      indices_offset->add_dim(dim_pair.second);
-    }
-    parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
-    parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
-    parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
-  }
+                                const ManualShapeMap &manual_shape_map) {
   if (!CheckPath(save_file_)) {
     MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
   }
-  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
-  if (!parallel_strategy_map.SerializeToOstream(&output)) {
-    MS_LOG(ERROR) << "Save strategy file failed";
-    return FAILED;
-  }
-  output.close();
+  strategy_checkpoint_info_.Init(strategy_map, tensor_info_map, manual_shape_map, ++current_stage_);
+  if (save_format_json_) {
+    auto stra_ckpt_info_j = strategy_checkpoint_info_.to_json();
+    std::fstream output(save_file_, std::ios::out);
+    stra_ckpt_info_j >> output;
+    output.close();
+  } else {
+    auto parallel_strategy_map = strategy_checkpoint_info_.to_protobuf();
+    std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
+    if (!parallel_strategy_map.SerializeToOstream(&output)) {
+      MS_LOG(ERROR) << "Save strategy file failed";
+      return FAILED;
+    }
+    output.close();
+  }

   ChangeFileMode(save_file_, S_IRUSR | S_IWUSR);
   return SUCCESS;
 }

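Because the JSON branch writes plain text, a saved strategy file can now be inspected without any protobuf tooling. A minimal sketch under that assumption, using only the standard library (file name illustrative):

    import json

    # keys follow the schema written by StrategyCheckpointInfo::to_json()
    with open("strategy.json") as f:
        info = json.load(f)
    print(info["current_stage"])
    for node_name, item in info["parallel_strategy_item"].items():
        print(node_name, item["stage"], item["parallel_strategy"])
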
@ -22,20 +22,14 @@
 #include <memory>
 #include <utility>
 #include "utils/hash_map.h"
 #include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/strategy.h"
 #include "include/common/utils/parallel_context.h"
 #include "frontend/parallel/tensor_layout/tensor_layout.h"
 #include "frontend/parallel/tensor_layout/tensor_info.h"
+#include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h"

 namespace mindspore {
 namespace parallel {
-using StrategyMap = mindspore::HashMap<std::string, StrategyPtr>;
-using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
-using TensorInfoMap = mindspore::HashMap<std::string, TensorLayoutPtr>;
-using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
-using ManualShapeMap = mindspore::HashMap<std::string, std::vector<std::pair<int64_t, int64_t>>>;
-using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
 class StrategyCheckpoint {
  public:
  StrategyCheckpoint() {

@ -47,7 +41,8 @@ class StrategyCheckpoint {

   Status Load(StrategyMap *strategy_map);
   Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) const;
-  Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
+  Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
+              const ManualShapeMap &manual_shape_map);
   Status SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list);
   bool group_info_save_on() const { return group_info_save_on_; }

@ -65,6 +60,9 @@ class StrategyCheckpoint {
   int64_t current_stage_ = 0;
   std::string group_info_save_file_;
   bool group_info_save_on_ = false;
+  bool load_format_json_ = true;
+  bool save_format_json_ = true;
+  StrategyCheckpointInfo strategy_checkpoint_info_;
 };
 } // namespace parallel
 } // namespace mindspore

@ -0,0 +1,179 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h"
+
+#include <fstream>
+#include <vector>
+#include <utility>
+
+namespace mindspore {
+namespace parallel {
+void StrategyCheckpointInfo::set_strategy_map(const StrategyMap &strategy_map) { strategy_map_ = strategy_map; }
+
+void StrategyCheckpointInfo::set_tensor_info_map(const TensorInfoMap &tensor_info_map) {
+  tensor_info_map_ = tensor_info_map;
+}
+
+void StrategyCheckpointInfo::set_manual_shape_map(const ManualShapeMap &manual_shape_map) {
+  manual_shape_map_ = manual_shape_map;
+}
+
+void StrategyCheckpointInfo::from_json(const nlohmann::json &stra_ckpt_info_j) {
+  current_stage_ = stra_ckpt_info_j.at("current_stage").get<int64_t>();
+  for (const auto &stra_j : stra_ckpt_info_j.at("parallel_strategy_item").items()) {
+    auto node_name = stra_j.key();
+    auto stage = stra_j.value().at("stage").get<int64_t>();
+    auto stra = stra_j.value().at("parallel_strategy").get<std::vector<std::vector<int64_t>>>();
+    strategy_map_[node_name] = std::make_shared<Strategy>(stage, stra);
+  }
+  for (const auto &layout_j : stra_ckpt_info_j.at("parallel_layout_item").items()) {
+    auto parameter_name = layout_j.key();
+    auto dev_matrix = layout_j.value().at("dev_matrix").get<std::vector<int64_t>>();
+    auto tensor_map = layout_j.value().at("tensor_map").get<std::vector<int64_t>>();
+    auto tensor_shape = layout_j.value().at("tensor_shape").get<std::vector<int64_t>>();
+    auto field = layout_j.value().at("field").get<int64_t>();
+    auto opt_weight_shard_step = layout_j.value().at("opt_weight_shard_step").get<int64_t>();
+    auto opt_weight_shard_size = layout_j.value().at("opt_weight_shard_size").get<int64_t>();
+    if (layout_j.value().contains("param_split_shape") && layout_j.value().contains("indices_offset")) {
+      auto param_split_shape = layout_j.value().at("param_split_shape").get<std::vector<int64_t>>();
+      auto indices_offset = layout_j.value().at("indices_offset").get<std::vector<int64_t>>();
+      if (param_split_shape.size() != indices_offset.size()) {
+        MS_LOG(EXCEPTION) << "For field_split strategy, the size of param_split_shape " << param_split_shape.size()
+                          << " is not equal to the size of indices_offset " << indices_offset.size();
+      }
+      for (size_t i = 0; i < param_split_shape.size(); ++i) {
+        manual_shape_map_[parameter_name].push_back({param_split_shape[i], indices_offset[i]});
+      }
+    }
+    tensor_info_map_[parameter_name] = std::make_shared<TensorLayout>();
+    tensor_info_map_[parameter_name]->InitFromVector(dev_matrix, tensor_map, tensor_shape);
+    tensor_info_map_[parameter_name]->set_opt_weight_shard_size(opt_weight_shard_size);
+    tensor_info_map_[parameter_name]->set_opt_weight_shard_step(opt_weight_shard_step);
+    tensor_info_map_[parameter_name]->set_field_size(field);
+  }
+}
+
+nlohmann::json StrategyCheckpointInfo::to_json() const {
+  nlohmann::json stra_ckpt_info_j;
+  stra_ckpt_info_j["current_stage"] = current_stage_;
+  for (const auto &stra_pair : strategy_map_) {
+    auto node_name = stra_pair.first;
+    auto node_stra = stra_pair.second;
+    nlohmann::json stra_j;
+    stra_j["stage"] = node_stra->GetInputStage();
+    stra_j["parallel_strategy"] = node_stra->GetInputDim();
+    stra_ckpt_info_j["parallel_strategy_item"][node_name] = stra_j;
+  }
+  for (const auto &layout_pair : tensor_info_map_) {
+    auto parameter_name = layout_pair.first;
+    auto layout = layout_pair.second;
+    nlohmann::json layout_j;
+    layout_j["dev_matrix"] = layout->device_arrangement().array();
+    layout_j["tensor_map"] = layout->tensor_map().array();
+    layout_j["tensor_shape"] = layout->tensor_shape().array();
+    layout_j["field"] = layout->get_field_size();
+    layout_j["opt_weight_shard_step"] = layout->opt_weight_shard_step();
+    layout_j["opt_weight_shard_size"] = layout->opt_weight_shard_size();
+    if (manual_shape_map_.find(parameter_name) != manual_shape_map_.end()) {
+      auto manual_shape = manual_shape_map_.at(parameter_name);
+      for (auto dim_pair : manual_shape) {
+        layout_j["param_split_shape"].push_back(dim_pair.first);
+        layout_j["indices_offset"].push_back(dim_pair.second);
+      }
+    }
+    stra_ckpt_info_j["parallel_layout_item"][parameter_name] = layout_j;
+  }
+  return stra_ckpt_info_j;
+}
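Taken together, to_json()/from_json() above imply a document shaped roughly like the following hand-written Python literal. This is only an illustration of the schema; the parameter names and numbers are made up:

    strategy_json = {
        "current_stage": 1,
        "parallel_strategy_item": {
            "MatMul-op0": {"stage": 0, "parallel_strategy": [[2, 2], [2, 2]]},
        },
        "parallel_layout_item": {
            "network.matmul.weight": {
                "dev_matrix": [2, 2, 2],
                "tensor_map": [1, 0],
                "tensor_shape": [32, 32],
                "field": 0,
                "opt_weight_shard_step": 0,
                "opt_weight_shard_size": 0,
                # "param_split_shape" / "indices_offset" appear only for
                # field-sliced parameters (the manual_shape_map entries)
            },
        },
    }
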
+
+void StrategyCheckpointInfo::from_protobuf(const straspb::ParallelStrategyMap &parallel_strategy_map) {
+  size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
+  for (size_t i = 0; i < node_num; i++) {
+    straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
+    std::string node_name = parallel_strategy_item.node_name();
+    straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
+    int64_t stage = SizeToLong(parallel_strategys.stage());
+    size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
+    Strategies strategy_inputs;
+    for (size_t j = 0; j < strategys_num; j++) {
+      straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
+      Dimensions dimension;
+      size_t dim_num = LongToSize(parallel_strategy.dim_size());
+      for (size_t k = 0; k < dim_num; k++) {
+        dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
+      }
+      strategy_inputs.push_back(dimension);
+    }
+    StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
+    strategy_map_[node_name] = strategy;
+    current_stage_ = SizeToLong(parallel_strategy_map.current_stage());
+  }
+}
+
+straspb::ParallelStrategyMap StrategyCheckpointInfo::to_protobuf() const {
+  straspb::ParallelStrategyMap parallel_strategy_map;
+  parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(current_stage_)));
+  for (auto &node_stra : strategy_map_) {
+    straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
+    MS_EXCEPTION_IF_NULL(parallel_strategy_item);
+    parallel_strategy_item->set_node_name(node_stra.first);
+    straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
+    MS_EXCEPTION_IF_NULL(parallel_strategys);
+    MS_EXCEPTION_IF_NULL(node_stra.second);
+    parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
+    for (auto &dims : node_stra.second->GetInputDim()) {
+      straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
+      MS_EXCEPTION_IF_NULL(parallel_strategy);
+      for (auto stra_dim : dims) {
+        parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
+      }
+    }
+  }
+  for (auto &node_tensor_info : tensor_info_map_) {
+    TensorLayoutPtr tensor_layout = node_tensor_info.second;
+    MS_EXCEPTION_IF_NULL(tensor_layout);
+    straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
+    MS_EXCEPTION_IF_NULL(parallel_layout_item);
+    parallel_layout_item->set_param_name(node_tensor_info.first);
+    straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
+    straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
+    MS_EXCEPTION_IF_NULL(dev_matrix);
+    for (auto dev_dim : tensor_layout->device_arrangement().array()) {
+      dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
+    }
+    straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
+    MS_EXCEPTION_IF_NULL(tensor_map);
+    for (auto map_dim : tensor_layout->tensor_map().array()) {
+      tensor_map->add_dim(LongToInt(map_dim));
+    }
+    straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
+    straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
+    parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
+    parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
+    parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
+    if (manual_shape_map_.find(node_tensor_info.first) != manual_shape_map_.end()) {
+      auto manual_shape = manual_shape_map_.at(node_tensor_info.first);
+      for (auto dim_pair : manual_shape) {
+        param_split_shape->add_dim(dim_pair.first);
+        indices_offset->add_dim(dim_pair.second);
+      }
+    }
+  }
+  return parallel_strategy_map;
+}
+} // namespace parallel
+} // namespace mindspore
@ -0,0 +1,74 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_
+
+#include <string>
+#include <vector>
+#include <memory>
+#include <utility>
+#include "nlohmann/json.hpp"
+#include "utils/hash_map.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+#include "frontend/parallel/tensor_layout/tensor_info.h"
+#include "proto/node_strategy.pb.h"
+
+namespace mindspore {
+namespace parallel {
+using StrategyMap = mindspore::HashMap<std::string, StrategyPtr>;
+using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
+using TensorInfoMap = mindspore::HashMap<std::string, TensorLayoutPtr>;
+using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
+using ManualShapeMap = mindspore::HashMap<std::string, std::vector<std::pair<int64_t, int64_t>>>;
+using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
+
+class StrategyCheckpointInfo {
+ public:
+  StrategyCheckpointInfo() {}
+  ~StrategyCheckpointInfo() = default;
+  void Init(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
+            const ManualShapeMap &manual_shape_map, int64_t current_stage) {
+    strategy_map_ = strategy_map;
+    tensor_info_map_ = tensor_info_map;
+    manual_shape_map_ = manual_shape_map;
+    current_stage_ = current_stage;
+  }
+  StrategyMap strategy_map() const { return strategy_map_; }
+  void set_strategy_map(const StrategyMap &strategy_map);
+  TensorInfoMap tensor_info_map() const { return tensor_info_map_; }
+  void set_tensor_info_map(const TensorInfoMap &tensor_info_map);
+  ManualShapeMap manual_shape_map() const { return manual_shape_map_; }
+  void set_manual_shape_map(const ManualShapeMap &manual_shape_map);
+  int64_t current_stage() const { return current_stage_; }
+
+  void from_json(const nlohmann::json &stra_ckpt_info_j);
+  nlohmann::json to_json() const;
+
+  void from_protobuf(const straspb::ParallelStrategyMap &parallel_strategy_map);
+  straspb::ParallelStrategyMap to_protobuf() const;
+
+ private:
+  int64_t current_stage_;
+  StrategyMap strategy_map_;
+  TensorInfoMap tensor_info_map_;
+  ManualShapeMap manual_shape_map_;
+};
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_

@ -61,11 +61,13 @@ set(MSLIB_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc

 if(ENABLE_D)
     list(APPEND MSLIB_INFER_SRC
+        "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc"
         "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc"
         "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc"
         "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc"
         "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc"
-        "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc")
+        "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc"
+        "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc")
 endif()

 add_library(mindspore_infer_shared_lib_obj OBJECT ${MSLIB_INFER_SRC})

@ -16,6 +16,7 @@
 from __future__ import absolute_import

 import os
+import json
 import numpy as np
 import mindspore as ms
 from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \

@ -81,7 +82,7 @@ def _convert_to_layout(param_name, tensor_layout):
     return strategy


-def _load_strategy_file(strategy_filename):
+def _check_strategy_file(strategy_filename):
     """load parallel strategy file"""
     if not isinstance(strategy_filename, str):
         raise TypeError(f"For 'build_searched_strategy', the argument 'strategy_filename' should be string, "

@ -94,18 +95,25 @@ def _load_strategy_file(strategy_filename):
     if os.path.getsize(strategy_filename) == 0:
         raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} should not "
                          f"be empty. Please check whether the 'strategy_filename' is correct.")
-    parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()


+def _load_protobuf_strategy(strategy_filename):
+    """load strategy from protobuf file"""
+    parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
     with open(strategy_filename, 'rb') as f:
         pb_content = f.read()
-    parallel_strategy_map.ParseFromString(pb_content)
+    try:
+        parallel_strategy_map.ParseFromString(pb_content)
+    except BaseException as e:
+        raise TypeError("The strategy file type should be one of json or protobuf. "
+                        "When the file name extension is not '.json', "
+                        "the file is considered as a protobuf file.") from e
     return parallel_strategy_map


-def _build_searched_strategy(strategy_filename):
-    """build searched strategy"""
-    parallel_strategy_map = _load_strategy_file(strategy_filename)
+def _build_protobuf_strategy(strategy_filename):
+    """build strategy from protobuf file"""
+    parallel_strategy_map = _load_protobuf_strategy(strategy_filename)
     layout_items = parallel_strategy_map.parallel_layout_item
     if not layout_items:
         raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} has no sliced "

@ -116,10 +124,94 @@
         parameter_name = layout_item.param_name
         layout = layout_item.parallel_layouts
         strategy[parameter_name] = layout

     return strategy


+def _build_json_strategy(strategy_filename):
+    """build strategy from json file"""
+    with open(strategy_filename, 'r') as f:
+        json_content = json.load(f)
+    layout_items = json_content.get("parallel_layout_item")
+    strategy = {}
+    for parameter_name, layout_item in layout_items.items():
+        layout = ms.train.node_strategy_pb2.ParallelLayouts()
+        layout.field = layout_item.get("field")
+        layout.opt_weight_shard_size = layout_item.get("opt_weight_shard_size")
+        layout.opt_weight_shard_step = layout_item.get("opt_weight_shard_step")
+        dev_matrix = layout.dev_matrix.add()
+        for item in layout_item.get("dev_matrix"):
+            dev_matrix.dim.append(item)
+        tensor_map = layout.tensor_map.add()
+        for item in layout_item.get("tensor_map"):
+            tensor_map.dim.append(item)
+        param_split_shape = layout.param_split_shape.add()
+        if "param_split_shape" in layout_item:
+            for item in layout_item.get("param_split_shape"):
+                param_split_shape.dim.append(item)
+        indices_offset = layout.indices_offset.add()
+        if "indices_offset" in layout_item:
+            for item in layout_item.get("indices_offset"):
+                indices_offset.dim.append(item)
+        strategy[parameter_name] = layout
+    return strategy
+
+
+def _build_searched_strategy(strategy_filename):
+    """build searched strategy"""
+    _check_strategy_file(strategy_filename)
+    if strategy_filename[-5:] != ".json":
+        return _build_protobuf_strategy(strategy_filename)
+    return _build_json_strategy(strategy_filename)

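Both formats therefore come back as the same mapping of parameter names to ParallelLayouts messages. A usage sketch of the private helper above (private API, file name illustrative):

    from mindspore.parallel._parallel_serialization import _build_searched_strategy

    # a ".json" suffix is parsed with json.load, anything else as protobuf
    strategy = _build_searched_strategy("./strategy.json")
    for param_name, layouts in strategy.items():
        print(param_name, list(layouts.dev_matrix[0].dim))
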
+
+
+def _merge_protobuf_strategy(src_strategy_files, dst_strategy_file):
+    """merge protobuf strategy"""
+    dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
+    merged_stage = []
+    for src_strategy_file in src_strategy_files:
+        src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
+        strategy_items = src_parallel_strategy_map.parallel_strategy_item
+        layout_items = src_parallel_strategy_map.parallel_layout_item
+        if not strategy_items or not layout_items:
+            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
+        pipeline_stage = strategy_items[0].parallel_strategys.stage
+        if pipeline_stage in merged_stage:
+            continue
+        for layout_item in layout_items:
+            layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name])
+        dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items)
+        dst_parallel_strategy_map.parallel_layout_item.extend(layout_items)
+        merged_stage.append(pipeline_stage)
+    dst_parallel_strategy_map.current_stage = 1
+    with open(dst_strategy_file, "wb") as f:
+        f.write(dst_parallel_strategy_map.SerializeToString())
+
+
+def _merge_json_strategy(src_strategy_files, dst_strategy_file):
+    """merge json strategy"""
+    dst_parallel_strategy_map = {"current_stage": 1, "parallel_strategy_item": {}, "parallel_layout_item": {}}
+    merged_stage = []
+    for src_strategy_file in src_strategy_files:
+        with open(src_strategy_file, 'r') as f:
+            json_content = json.load(f)
+        layout_items = json_content.get("parallel_layout_item")
+        strategy_items = json_content.get("parallel_strategy_item")
+        if not strategy_items or not layout_items:
+            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
+        pipeline_stage = strategy_items.get(list(strategy_items.keys())[0]).get('stage')
+        if pipeline_stage in merged_stage:
+            continue
+        for param_name, layout_item in layout_items.items():
+            new_layout_item = {}
+            new_param_name = "-".join([str(pipeline_stage), param_name])
+            new_layout_item[new_param_name] = layout_item
+            dst_parallel_strategy_map.get("parallel_layout_item").update(new_layout_item)
+        dst_parallel_strategy_map.get("parallel_strategy_item").update(strategy_items)
+        merged_stage.append(pipeline_stage)
+    with open(dst_strategy_file, "w") as f:
+        json.dump(dst_parallel_strategy_map, f)


 def _parameter_not_in_local_stage(param_name, origin_strategy_list, strategy_list):
     """parameter whether in the local stage"""
     if origin_strategy_list is None or strategy_list is None:

|
|
@ -22,8 +22,9 @@ from collections import defaultdict
|
|||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
|
||||
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, _load_strategy_file, \
|
||||
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num
|
||||
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
|
||||
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
|
||||
_merge_protobuf_strategy, _merge_json_strategy
|
||||
|
||||
|
||||
__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
|
||||
|
@ -55,26 +56,16 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
     _make_dir(dst_strategy_dir, "path")
     if not os.path.isdir(src_strategy_dirs):
         raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs))
-    src_strategy_files = os.path.join(src_strategy_dirs, "*.ckpt")
-    dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
-    merged_stage = []
-    for src_strategy_file in glob.glob(src_strategy_files):
-        src_parallel_strategy_map = _load_strategy_file(src_strategy_file)
-        strategy_items = src_parallel_strategy_map.parallel_strategy_item
-        layout_items = src_parallel_strategy_map.parallel_layout_item
-        if not strategy_items or not layout_items:
-            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
-        pipeline_stage = strategy_items[0].parallel_strategys.stage
-        if pipeline_stage in merged_stage:
-            continue
-        for layout_item in layout_items:
-            layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name])
-        dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items)
-        dst_parallel_strategy_map.parallel_layout_item.extend(layout_items)
-        merged_stage.append(pipeline_stage)
-    dst_parallel_strategy_map.current_stage = 1
-    with open(dst_strategy_file, "wb") as f:
-        f.write(dst_parallel_strategy_map.SerializeToString())
+    src_strategy_files_protobuf = glob.glob(os.path.join(src_strategy_dirs, "*.ckpt"))
+    src_strategy_files_json = glob.glob(os.path.join(src_strategy_dirs, "*.json"))
+    if src_strategy_files_protobuf and src_strategy_files_json:
+        raise ValueError("The strategys format should be all '.ckpt' or all '.json'")
+    is_protobuf = len(src_strategy_files_protobuf) > 0
+    if is_protobuf:
+        _merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file)
+    else:
+        _merge_json_strategy(src_strategy_files_json, dst_strategy_file)

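A short sketch of calling the rewritten entry point, assuming the public export implied by __all__ in the previous hunk (paths illustrative; all files in the directory must share one suffix, as the ValueError above enforces):

    from mindspore.parallel import merge_pipeline_strategys

    # merge the per-stage strategy files (all ".ckpt" or all ".json") into one file
    merge_pipeline_strategys("./strategies", "./merged_strategy.json")
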


 def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):

@ -106,7 +97,7 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
         >>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
         >>> checkpoint_files_map = {}
         >>> for rank in rank_list:
-        >>>     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
+        ...     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)

     """
     if not isinstance(rank_id, int):

@ -167,14 +158,14 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_

     Examples:
         >>> dst_device_num = 8
-        >>> for rank_id in range(dst_device_num)
-        >>>     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
-        >>>     checkpoint_files_map = {}
-        >>>     for rank in rank_list:
-        >>>         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank)
-        >>>     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id)
-        >>>     transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
-        >>>                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")
+        >>> for rank_id in range(dst_device_num):
+        ...     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
+        ...     checkpoint_files_map = {}
+        ...     for rank in rank_list:
+        ...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank)
+        ...     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id)
+        ...     transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
+        ...                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")

     """
     if not isinstance(checkpoint_files_map, dict):

|
|
@ -30,7 +30,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f
|
|||
Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
|
||||
|
||||
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
|
||||
ManualShapeMap *manual_shape_map) { return SUCCESS; }
|
||||
const ManualShapeMap &manual_shape_map) { return SUCCESS; }
|
||||
|
||||
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file,
|
||||
GroupInfoMap *group_info_map) const { return SUCCESS; }
|
||||
|
|
|
@ -201,7 +201,7 @@ def test_six_matmul_save_auto():
             return out

     reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.json")
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel", dataset_strategy="full_batch")
     x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)

@ -256,7 +256,7 @@ def six_matmul_load_auto():
             return out

     reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.json")
     strategy1 = ((2, 2), (2, 2))
     strategy3 = ((2, 2), (2, 2))
     strategy4 = ((2, 2), (2, 2))