forked from OSSInnovation/mindspore
add_tensor_layout_in_stra_ckpt
parent 57fd31b221
commit 60a9fb0001
@@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
   std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+  const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; }
+  const std::vector<int64_t> &index_offsets() const { return index_offsets_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
   // key of strategy map
-  std::string strategy_key_name = NodeParameterName(cnode);
+  std::string strategy_key_name = "";
+  auto param_names = NodeParameterName(cnode);
+  if (!param_names.empty()) {
+    strategy_key_name = param_names[0].first;
+  }
   bool load_strategy_from_ckpt =
     StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
   // If no strategy has been configured for this operator, then candidate strategies are generated for
@@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
     }
     // load strategy checkpoint
     // key of strategy map
-    std::string strategy_key_name = NodeParameterName(cnode);
+    std::string strategy_key_name = "";
+    auto param_names = NodeParameterName(cnode);
+    if (!param_names.empty()) {
+      strategy_key_name = param_names[0].first;
+    }
     bool load_strategy_from_ckpt =
       StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
     if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
@@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
   }
 }
 
-std::string NodeParameterName(const CNodePtr &node) {
+std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node) {
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
-  for (auto input : node_inputs) {
+  std::vector<std::pair<std::string, int>> param_names;
+  for (int i = 0; i < UintToInt(node_inputs.size()); ++i) {
+    auto input = node_inputs[i];
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
       if (input_parameter->has_default()) {
-        input_parameter->name();
+        if (ParameterRequireGrad(input_parameter)) {
+          param_names.push_back({input_parameter->name(), i});
+        }
       }
     }
   }
-  return "";
+  return param_names;
 }
 
 void CheckpointStrategy(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
   StrategyMap stra_map;
+  TensorInfoMap tensor_info_map;
+  ManualShapeMap manual_shape_map;
   auto ret = func_graph->get_return();
   auto all_nodes = DeepScopedGraphSearch(ret);
   for (auto &node : all_nodes) {
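Editor's illustration (plain standard C++, not the MindSpore API): NodeParameterName now returns (parameter name, input index) pairs. Input 0 of a CNode is the primitive itself, so a trainable parameter found at input index i lines up with entry i - 1 of the operator's inputs_tensor_info(), which is why the checkpoint code further below subtracts 1. A minimal sketch with hypothetical input names:

// Illustration only, plain standard C++ (not MindSpore types).
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical CNode inputs: index 0 is the primitive, operands follow.
  // Only "embedding_table" is a trainable parameter in this sketch.
  std::vector<std::pair<std::string, bool>> inputs = {
      {"GatherV2", false}, {"indices", false}, {"embedding_table", true}};

  // Mirrors the reworked NodeParameterName: collect (name, input index) pairs.
  std::vector<std::pair<std::string, int>> param_names;
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    if (inputs[i].second) {
      param_names.push_back({inputs[i].first, i});
    }
  }

  // inputs_tensor_info() has no entry for the primitive, hence "index - 1"
  // when a parameter is matched to its TensorInfo.
  for (const auto &p : param_names) {
    std::cout << p.first << " -> inputs_tensor_info()[" << (p.second - 1) << "]" << std::endl;
  }
  return 0;
}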
@@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
-    std::string param_name = NodeParameterName(cnode);
-    if (param_name.empty()) {
+    auto param_names = NodeParameterName(cnode);
+    if (param_names.empty()) {
       continue;
     }
+    string param_name = param_names[0].first;
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
     MS_EXCEPTION_IF_NULL(prim);
     OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
@@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
       if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
         continue;
       }
+      std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
       StrategyPtr strategyPtr = operator_info->strategy();
       MS_EXCEPTION_IF_NULL(node->scope());
       stra_map[param_name] = strategyPtr;
+      for (auto param_name_pair : param_names) {
+        if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
+          continue;
+        }
+        tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
+      }
+      if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos ||
+          operator_info->name().find(GATHERV2) != std::string::npos) {
+        auto gatherv2_info = std::dynamic_pointer_cast<GatherV2PInfo>(operator_info);
+        auto param_split_shapes = gatherv2_info->param_split_shapes();
+        auto index_offsets = gatherv2_info->index_offsets();
+        if (param_split_shapes.size() != index_offsets.size()) {
+          MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be the same.";
+        }
+        std::vector<std::pair<int32_t, int32_t>> manual_shape;
+        for (int i = 0; i < UintToInt(param_split_shapes.size()); ++i) {
+          manual_shape.push_back({param_split_shapes[i], index_offsets[i]});
+        }
+        manual_shape_map[param_name] = manual_shape;
+      }
     }
   }
-  if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
+  if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
     MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
   }
 }
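As a worked example of the manual-split bookkeeping above (hypothetical numbers, plain C++ instead of the MindSpore types): a GatherV2 parameter manually split into two row blocks of 4000 rows would report param_split_shapes {4000, 4000} and index_offsets {0, 4000}, which the loop pairs into {(4000, 0), (4000, 4000)} before storing them in manual_shape_map:

// Illustration only, hypothetical values, plain standard C++ (not MindSpore types).
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Assumed manual split: two row blocks of 4000 rows, starting at rows 0 and 4000.
  std::vector<int64_t> param_split_shapes = {4000, 4000};
  std::vector<int64_t> index_offsets = {0, 4000};

  // Pair each split shape with its index offset, as CheckpointStrategy does
  // before storing the result in manual_shape_map[param_name].
  std::vector<std::pair<int32_t, int32_t>> manual_shape;
  for (size_t i = 0; i < param_split_shapes.size(); ++i) {
    manual_shape.push_back({static_cast<int32_t>(param_split_shapes[i]),
                            static_cast<int32_t>(index_offsets[i])});
  }

  for (const auto &p : manual_shape) {
    std::cout << "split_shape=" << p.first << " index_offset=" << p.second << std::endl;
  }
  return 0;
}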
@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager);
 
-std::string NodeParameterName(const CNodePtr &node);
+std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
 
 void CheckpointStrategy(const FuncGraphPtr &func_graph);
 
@@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
   return SUCCESS;
 }
 
-Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
+Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
+                                ManualShapeMap *manual_shape_map) {
   straspb::ParallelStrategyMap parallel_strategy_map;
   parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
   for (auto &node_stra : strategy_map) {
@@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
       }
     }
   }
+  for (auto &node_tensor_info : tensor_info_map) {
+    TensorInfo tensor_info = node_tensor_info.second;
+    TensorLayout tensor_layout = tensor_info.tensor_layout();
+    straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
+    MS_EXCEPTION_IF_NULL(parallel_layout_item);
+    parallel_layout_item->set_param_name(node_tensor_info.first);
+    straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
+    straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
+    MS_EXCEPTION_IF_NULL(dev_matrix);
+    for (auto dim : tensor_layout.device_arrangement().array()) {
+      dev_matrix->add_dim(IntToUint(dim));
+    }
+    straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
+    MS_EXCEPTION_IF_NULL(tensor_map);
+    for (auto dim : tensor_layout.tensor_map().array()) {
+      tensor_map->add_dim(dim);
+    }
+    straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
+    straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
+    MS_EXCEPTION_IF_NULL(manual_shape_map);
+    auto manual_shape = (*manual_shape_map)[node_tensor_info.first];
+    for (auto dim_pair : manual_shape) {
+      param_split_shape->add_dim(dim_pair.first);
+      indices_offset->add_dim(dim_pair.second);
+    }
+  }
+
   std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
   if (!parallel_strategy_map.SerializeToOstream(&output)) {
     MS_LOG(ERROR) << "Save strategy file failed";
@@ -19,13 +19,19 @@
 
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
+#include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/context.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+#include "frontend/parallel/tensor_layout/tensor_info.h"
 
 namespace mindspore {
 namespace parallel {
 using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
+using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
+using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int32_t, int32_t>>>;
 class StrategyCheckpoint {
  public:
  StrategyCheckpoint() {
@@ -32,7 +32,36 @@ message ParallelStrategyItem {
     required ParallelStrategys parallel_strategys = 2;
 }
 
+message DevMatrix {
+    repeated uint32 dim = 1;
+}
+
+message TensorMap {
+    repeated int32 dim = 1;
+}
+
+message ParamSplitShape {
+    repeated int64 dim = 1;
+}
+
+message IndicesOffset {
+    repeated int64 dim = 1;
+}
+
+message ParallelLayouts {
+    repeated DevMatrix dev_matrix = 1;
+    repeated TensorMap tensor_map = 2;
+    repeated ParamSplitShape param_split_shape = 3;
+    repeated IndicesOffset indices_offset = 4;
+}
+
+message ParallelLayoutItem {
+    required string param_name = 1;
+    required ParallelLayouts parallel_layouts = 2;
+}
+
 message ParallelStrategyMap {
     required uint32 current_stage = 1;
     repeated ParallelStrategyItem parallel_strategy_item = 2;
+    repeated ParallelLayoutItem parallel_layout_item = 3;
 }
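For context, the layouts written by the extended Save can be read back with the standard generated protobuf accessors. A minimal sketch, assuming the generated header for this .proto is available and the straspb namespace alias used in the Save implementation above is visible; the header name here is a placeholder, not the real path:

// Sketch only: read back the tensor layouts written by StrategyCheckpoint::Save.
// "layout_ckpt.pb.h" is a placeholder for the real generated header, and straspb
// is assumed to be the namespace alias used in the Save implementation above.
#include <fstream>
#include <iostream>
#include <string>
#include "layout_ckpt.pb.h"

void DumpLayouts(const std::string &ckpt_file) {
  straspb::ParallelStrategyMap parallel_strategy_map;
  std::fstream input(ckpt_file, std::ios::in | std::ios::binary);
  if (!parallel_strategy_map.ParseFromIstream(&input)) {
    std::cerr << "Load strategy file failed" << std::endl;
    return;
  }
  for (int i = 0; i < parallel_strategy_map.parallel_layout_item_size(); ++i) {
    const straspb::ParallelLayoutItem &item = parallel_strategy_map.parallel_layout_item(i);
    const straspb::ParallelLayouts &layouts = item.parallel_layouts();
    std::cout << item.param_name();
    if (layouts.dev_matrix_size() > 0) {
      std::cout << " dev_matrix:";
      for (int d = 0; d < layouts.dev_matrix(0).dim_size(); ++d) {
        std::cout << " " << layouts.dev_matrix(0).dim(d);
      }
    }
    if (layouts.tensor_map_size() > 0) {
      std::cout << " tensor_map:";
      for (int d = 0; d < layouts.tensor_map(0).dim_size(); ++d) {
        std::cout << " " << layouts.tensor_map(0).dim(d);
      }
    }
    std::cout << std::endl;
  }
}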
@@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f
 
 Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
 
-Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { return SUCCESS; }
+Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
+                                ManualShapeMap *manual_shape_map) { return SUCCESS; }
 }  // namespace parallel
 }  // namespace mindspore