!27114 Add LayoutProto to support the pipeline for the compile cache and fix the package import problem.

Merge pull request !27114 from LiangZhibo/mindir
i-robot 2021-12-03 02:53:12 +00:00 committed by Gitee
commit 8a2ce38ce1
13 changed files with 240 additions and 33 deletions

View File

@ -32,7 +32,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph);
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &param_layout_fg = nullptr);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
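The extended declaration keeps the original call shape via the `param_layout_fg = nullptr` default, so existing callers compile unchanged while the compile-cache path can pass a layout graph. A minimal sketch of both call shapes, assuming `fg` and `layout_fg` are valid FuncGraphPtr values:

// Sketch only: fg and layout_fg are assumed to be valid FuncGraphPtr values.
ModelProtoPtr plain = GetBinaryProto(fg);                   // no layout information
ModelProtoPtr with_layout = GetBinaryProto(fg, layout_fg);  // layout embedded in the proto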

View File

@ -28,7 +28,7 @@
namespace mindspore {
namespace parallel {
py::dict GetParameterLayout(const FuncGraphPtr &graph) {
py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
py::dict dict;
std::vector<AnfNodePtr> graph_params = graph->parameters();
@ -44,7 +44,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
auto slice_shape = tensor_layout->slice_shape().array();
int64_t field_size = tensor_layout->get_field_size();
bool uniform_split = tensor_layout->uniform_split();
std::string opt_shard_group = tensor_layout->opt_shard_group();
const std::string &opt_shard_group = tensor_layout->opt_shard_group();
py::tuple layout =
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
dict[py::str(name)] = layout;
@ -54,6 +54,25 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
return dict;
}
py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
py::dict dict;
auto layout_map = resource->get_layout_map();
for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
auto name = iter->first;
auto layout = iter->second;
auto device_arrangement = layout->get_device_arrangement();
auto tensor_map = layout->get_tensor_map();
auto slice_shape = layout->get_slice_shape();
int64_t field_size = layout->get_field_size();
bool uniform_split = layout->get_uniform_split();
const std::string &opt_shard_group = layout->get_opt_shard_group();
py::tuple layout_tuple =
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
dict[py::str(name)] = layout_tuple;
}
return dict;
}
py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
py::dict dict;
@ -81,7 +100,7 @@ py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
}
// In pipeline parallel mode, many parameters are not used and need to be deleted
py::list GetParallelParameterNameList(const FuncGraphPtr &graph) {
py::list GetParallelParameterNameListFromGraph(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
py::list parallel_parameter_name_list;
@ -95,5 +114,15 @@ py::list GetParallelParameterNameList(const FuncGraphPtr &graph) {
}
return parallel_parameter_name_list;
}
py::list GetParallelParameterNameListFromResource(const pipeline::ResourcePtr &resource) {
auto &layout_map = resource->get_layout_map();
py::list parallel_parameter_name_list;
for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
auto name = iter->first;
parallel_parameter_name_list.append(name);
}
return parallel_parameter_name_list;
}
} // namespace parallel
} // namespace mindspore
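Both retrieval paths return the same dict shape, keyed by parameter name; the pipeline chooses between them depending on whether the stepped-parallel graph is available or only the layout map reloaded from the compile cache. A hypothetical helper sketching that dispatch (the real selection happens in GraphExecutorPy::GetParameterLayout below):

// Hypothetical wrapper; mirrors the dispatch done in pipeline.cc.
py::dict GetParameterLayoutAuto(const FuncGraphPtr &graph, const pipeline::ResourcePtr &resource) {
  if (graph != nullptr) {
    return GetParameterLayoutFromGraph(graph);  // layouts attached to the parameters
  }
  return GetParameterLayoutFromResource(resource);  // layouts reloaded from the cache
}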

View File

@ -20,13 +20,16 @@
#include "pybind11/stl.h"
#include "pybind11/pybind11.h"
#include "ir/anf.h"
#include "pipeline/jit/resource.h"
namespace py = pybind11;
namespace mindspore {
namespace parallel {
py::dict GetParameterLayout(const FuncGraphPtr &graph);
py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph);
py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource);
py::dict GetAllreduceFusion(const FuncGraphPtr &graph);
py::list GetParallelParameterNameList(const FuncGraphPtr &graph);
py::list GetParallelParameterNameListFromGraph(const FuncGraphPtr &graph);
py::list GetParallelParameterNameListFromResource(const pipeline::ResourcePtr &resource);
} // namespace parallel
} // namespace mindspore

View File

@ -298,7 +298,8 @@ std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
return ret;
}
FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx, const py::dict &weights) {
FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict &weights, bool need_layout) {
const size_t idx = resource->compile_cache_id();
std::string compile_cache_path = GetCompileCachePath(idx);
auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
if (!realpath.has_value()) {
@ -315,7 +316,12 @@ FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx, const py::dict &weights) {
MindIRLoader mindir_loader;
mindir_loader.set_need_renormalize(false);
mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
return mindir_loader.LoadMindIR(realpath.value());
mindir_loader.set_need_layout(need_layout);
auto fg = mindir_loader.LoadMindIR(realpath.value());
if (need_layout) {
resource->set_layout_map(mindir_loader.get_layout_map());
}
return fg;
}
FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &weights, const std::string &queue_name) {
@ -326,8 +332,14 @@ FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &wei
return nullptr;
}
// Determine whether to load layout information.
std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
bool need_layout = false;
if ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL)) {
need_layout = true;
}
// Load the compilation cache file.
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource->compile_cache_id(), weights);
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource, weights, need_layout);
if (fg == nullptr) {
MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
return nullptr;
@ -356,7 +368,7 @@ FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &wei
return fg;
}
bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, size_t idx) {
bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, size_t idx) {
std::string compile_cache_path = GetCompileCachePath(idx);
auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
if (!realpath.has_value()) {
@ -370,7 +382,7 @@ bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, size_t idx) {
MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
return false;
}
ModelProtoPtr fg_model = GetBinaryProto(fg);
ModelProtoPtr fg_model = GetBinaryProto(fg, layout_fg);
if (fg_model == nullptr) {
MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
fout.close();
@ -413,7 +425,13 @@ void CacheFuncGraph(const ResourcePtr &resource) {
MS_LOG(ERROR) << "The func_graph to be cached is null.";
return;
}
if (!ExportFuncGraphToMindIR(fg, resource->compile_cache_id())) {
FuncGraphPtr layout_fg = nullptr;
std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (fg->has_flag(parallel::AUTO_PARALLEL) &&
((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
layout_fg = resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();
}
if (!ExportFuncGraphToMindIR(fg, layout_fg, resource->compile_cache_id())) {
MS_LOG(ERROR) << "Failed to cache graph: " << fg->ToString();
return;
}
@ -524,7 +542,8 @@ ResourcePtr GraphExecutorPy::GetResource(const std::string &phase) {
FuncGraphPtr GraphExecutorPy::GetFuncGraph(const std::string &phase) {
if (info_.count(phase) == 0) {
MS_LOG(EXCEPTION) << "No executor info. found for phase: " << phase;
MS_LOG(INFO) << "No executor info. found for phase: " << phase;
return nullptr;
}
return info_[phase]->func_graph;
}
@ -636,7 +655,11 @@ py::dict GraphExecutorPy::GetParameterLayout(const std::string &phase) {
MS_LOG(DEBUG) << "GetParameterLayout!";
std::string layout_graph = phase + kStepParallelGraph;
auto graph = GetFuncGraph(layout_graph);
return mindspore::parallel::GetParameterLayout(graph);
if (graph == nullptr) {
auto resource = info_[phase]->resource;
return mindspore::parallel::GetParameterLayoutFromResource(resource);
}
return mindspore::parallel::GetParameterLayoutFromGraph(graph);
}
py::dict GraphExecutorPy::GetCNodeStrategy(const std::string &phase) {
@ -647,7 +670,11 @@ py::dict GraphExecutorPy::GetCNodeStrategy(const std::string &phase) {
py::list GraphExecutorPy::GetParallelParameterNameList(const std::string &phase) {
std::string param_graph = phase + kStepParallelGraph;
auto graph = GetFuncGraph(param_graph);
return mindspore::parallel::GetParallelParameterNameList(graph);
if (graph == nullptr) {
auto resource = info_[phase]->resource;
return mindspore::parallel::GetParallelParameterNameListFromResource(resource);
}
return mindspore::parallel::GetParallelParameterNameListFromGraph(graph);
}
void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
@ -840,11 +867,15 @@ void GraphExecutorPy::SaveCompiledGraph(const std::string &phase) {
if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) &&
((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
func_graph = info_[phase]->resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
std::string layout_graph = phase + kStepParallelGraph;
executor_info->func_graph = func_graph;
info_[layout_graph] = executor_info;
auto results = info_[phase]->resource->results();
// When using the frontend compile cache, the model parallel parameter layout graph is not saved.
if (results.find(kStepParallelGraph) != results.end()) {
func_graph = results[kStepParallelGraph].cast<FuncGraphPtr>();
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
std::string layout_graph = phase + kStepParallelGraph;
executor_info->func_graph = func_graph;
info_[layout_graph] = executor_info;
}
} else {
MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!";
}
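Taken together, these changes make the compile cache layout-aware in both directions. A condensed sketch of the round trip under auto or semi-auto parallel, assuming resource, weights, fg, and layout_fg are set up as in the code above:

// First compilation: persist the graph together with its layout graph.
(void)ExportFuncGraphToMindIR(fg, layout_fg, resource->compile_cache_id());
// Cached run: reload the graph and stash the layouts on the resource.
FuncGraphPtr cached = LoadFuncGraphFromMindIR(resource, weights, /*need_layout=*/true);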

View File

@ -34,6 +34,7 @@
#include "pipeline/jit/resource_base.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "load_mindir/load_model.h"
namespace mindspore {
namespace pipeline {
@ -87,6 +88,9 @@ class Resource : public ResourceBase {
bool vm_loop_flag() { return vm_loop_flag_; }
int64_t loop_size() { return loop_size_; }
void set_layout_map(const LayoutMap &layout_map) { layout_map_ = layout_map; }
const LayoutMap get_layout_map() const { return layout_map_; }
bool enable_compile_cache() { return enable_compile_cache_; }
void set_enable_compile_cache(bool enable_compile_cache) { enable_compile_cache_ = enable_compile_cache; }
@ -118,6 +122,7 @@ class Resource : public ResourceBase {
bool enable_compile_cache_{false};
size_t compile_cache_id_{0};
std::string compile_cache_dep_files_hash_;
LayoutMap layout_map_{};
};
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
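The resource is the hand-off point between the MindIR loader and the parallel queries above. Note that get_layout_map() returns the map by value, so callers pay for a copy. A minimal sketch of the hand-off, assuming a loader that ran with need_layout enabled:

resource->set_layout_map(loader.get_layout_map());     // stash layouts from the cache file
const LayoutMap layouts = resource->get_layout_map();  // note: returns a copy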

View File

@ -30,6 +30,7 @@
#include "debug/dump_proto.h"
#include "utils/ms_utils.h"
#include "utils/utils.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
namespace mindspore {
using FloatPtr = std::shared_ptr<Float>;
@ -83,7 +84,7 @@ class IrExporter {
explicit IrExporter(IrExportBuilderPtr builder) : builder_(std::move(builder)) {}
virtual ~IrExporter() = default;
std::string GetDumpString(const FuncGraphPtr &func_graph);
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph);
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &param_layout_fg = nullptr);
private:
IrExportBuilderPtr builder_;
@ -98,6 +99,8 @@ class IrExportBuilder {
bool BuildModel(const FuncGraphPtr &func_graph);
ModelProtoPtr Model() { return model_; }
void BuildLayout(const FuncGraphPtr &func_graph);
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
bool BuildFuncGraphAttrs(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
@ -161,7 +164,7 @@ std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
return builder_->GetProtoString();
}
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &param_layout_fg) {
if ((builder_ == nullptr) || (func_graph == nullptr)) {
MS_LOG(EXCEPTION) << "Input params is null.";
}
@ -173,6 +176,11 @@ ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
if (!builder_->BuildModel(func_graph)) {
return nullptr;
}
// Export layout information
if (param_layout_fg) {
builder_->BuildLayout(param_layout_fg);
}
return builder_->Model();
}
@ -190,6 +198,44 @@ void IrExportBuilder::BuildModelInfo() {
model_->set_little_endian(common::IsLittleByteOrder());
}
void IrExportBuilder::BuildLayout(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> graph_params = func_graph->parameters();
mind_ir::ParallelProto *parallel_proto = model_->mutable_parallel();
for (auto para : graph_params) {
std::string name = std::static_pointer_cast<Parameter>(para)->name();
auto tensor_layout = para->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
} else {
mind_ir::LayoutProto *layoutProto = parallel_proto->add_layout();
// Collect all the layout information.
auto device_arrangement = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
auto slice_shape = tensor_layout->slice_shape().array();
int64_t field_size = tensor_layout->get_field_size();
bool uniform_split = tensor_layout->uniform_split();
std::string opt_shard_group = tensor_layout->opt_shard_group();
// Save all information to Layout Proto
layoutProto->set_name(name);
for (auto device_arrangement_element : device_arrangement) {
layoutProto->add_device_arrangement_int(device_arrangement_element);
}
for (auto tensor_map_element : tensor_map) {
layoutProto->add_tensor_map_int(tensor_map_element);
}
for (auto slice_shape_element : slice_shape) {
layoutProto->add_slice_shape_int(slice_shape_element);
}
layoutProto->set_field_size(field_size);
layoutProto->set_uniform_split(uniform_split);
layoutProto->set_opt_shard_group(opt_shard_group);
}
}
}
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
mind_ir::GraphProto *graph_proto = model_->mutable_graph();
@ -1089,12 +1135,13 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
if (exporter == nullptr) {
return "";
}
return exporter->GetDumpString(func_graph);
auto ret = exporter->GetDumpString(func_graph);
return ret;
}
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) {
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &param_layout_fg) {
auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
auto result = exporter->GetDumpProto(func_graph);
auto result = exporter->GetDumpProto(func_graph, param_layout_fg);
return result;
}
} // namespace mindspore

View File

@ -124,7 +124,7 @@ def _get_filename_from_trace(trace):
return filename
def __get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_files):
def __get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_files, pkg):
"""Get the dependency files of the network"""
with open(file_path) as fh:
root = ast.parse(fh.read(), file_path)
@ -132,6 +132,8 @@ def __get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_f
module_name = ""
if isinstance(node, ast.ImportFrom):
module_name = node.module
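# A level of 1 marks a relative import ("from . import x"); prefix a dot
# so importlib can resolve the name against the caller's package.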
if node.level == 1:
module_name = "." + module_name
elif not isinstance(node, ast.Import):
continue
# Do not care the files in mindspore package
@ -148,10 +150,10 @@ def __get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_f
if n.name is not None:
whole_module += "." + n.name
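# Pass the caller's package so relative module names can be resolved.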
try:
module_spec = importlib.util.find_spec(whole_module)
module_spec = importlib.util.find_spec(whole_module, pkg)
except (ModuleNotFoundError, ValueError):
whole_module = whole_module[0:whole_module.rfind('.')]
module_spec = importlib.util.find_spec(whole_module)
module_spec = importlib.util.find_spec(whole_module, pkg)
if module_spec is None:
continue
module = importlib.util.module_from_spec(module_spec)
@ -162,7 +164,8 @@ def __get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_f
if not dep_file_path.startswith(python_bin_dir) and not dep_file_path in compile_cache_dep_files:
logger.debug(f"dependent file path: {dep_file_path}")
compile_cache_dep_files.append(dep_file_path)
__get_compile_cache_dep_files(dep_file_path, python_bin_dir, compile_cache_dep_files)
pkg = module.__package__
__get_compile_cache_dep_files(dep_file_path, python_bin_dir, compile_cache_dep_files, pkg)
def _get_compile_cache_dep_files():
@ -187,7 +190,7 @@ def _get_compile_cache_dep_files():
file_path = os.path.realpath(filename)
logger.debug(f"entry script file path: {file_path}")
compile_cache_dep_files.append(file_path)
__get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_files)
__get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_files, None)
return compile_cache_dep_files

View File

@ -1337,6 +1337,41 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
return dstGraph;
}
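// Rebuild the name -> Layout map from the ParallelProto section of a MindIR model.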
const LayoutMap MSANFModelParser::ParseLayout(const mind_ir::ModelProto &model_proto) {
LayoutMap ret;
mind_ir::ParallelProto parallel_proto = model_proto.parallel();
for (int i = 0; i < parallel_proto.layout_size(); ++i) {
const mind_ir::LayoutProto &layout_proto = parallel_proto.layout(i);
LayoutPtr cur_layout = std::make_shared<Layout>();
const std::string name = layout_proto.name();
std::vector<int64_t> device_arrangement;
for (int num = 0; num < layout_proto.device_arrangement_int_size(); ++num) {
device_arrangement.emplace_back(layout_proto.device_arrangement_int(num));
}
std::vector<int64_t> tensor_map;
for (int num = 0; num < layout_proto.tensor_map_int_size(); ++num) {
tensor_map.emplace_back(layout_proto.tensor_map_int(num));
}
std::vector<int64_t> slice_shape;
for (int num = 0; num < layout_proto.slice_shape_int_size(); ++num) {
slice_shape.emplace_back(layout_proto.slice_shape_int(num));
}
int64_t field_size = layout_proto.field_size();
bool uniform_split = layout_proto.uniform_split();
const std::string opt_shard_group = layout_proto.opt_shard_group();
cur_layout->set_device_arrangement(device_arrangement);
cur_layout->set_tensor_map(tensor_map);
cur_layout->set_slice_shape(slice_shape);
cur_layout->set_field_size(field_size);
cur_layout->set_uniform_split(uniform_split);
cur_layout->set_opt_shard_group(opt_shard_group);
ret[name] = cur_layout;
}
return ret;
}
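A short consumption sketch for the map ParseLayout returns, assuming model_proto has already been deserialized:

MSANFModelParser parser;
const LayoutMap layout_map = parser.ParseLayout(model_proto);
for (const auto &entry : layout_map) {
  const LayoutPtr &layout = entry.second;
  MS_LOG(INFO) << "Parameter " << entry.first << " has a slice of "
               << layout->get_slice_shape().size() << " dimensions.";
}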
AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
auto it = anfnode_build_map_.find(node_name);
if (it == anfnode_build_map_.end()) {

View File

@ -25,6 +25,7 @@
#include "ir/func_graph.h"
#include "proto/mind_ir.pb.h"
#include "utils/crypto.h"
#include "load_mindir/load_model.h"
namespace mindspore {
using int32 = int32_t;
using int64 = int64_t;
@ -36,6 +37,7 @@ class MSANFModelParser {
static void LoadTensorMapClear() { load_tensor_map_.clear(); }
FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto, const std::map<std::string, ValuePtr> &weights = {});
const LayoutMap ParseLayout(const mind_ir::ModelProto &model_proto);
bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto);
std::string GetProducerName() { return producer_name_; }

View File

@ -230,6 +230,9 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) {
model_parser.SetLite();
}
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model, weights_value_map_);
if (need_layout_) {
layout_map_ = model_parser.ParseLayout(origin_model);
}
return dstgraph_ptr;
}
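The loader change is opt-in: layout parsing only runs when need_layout is set. A usage sketch of the new knobs (the file name is illustrative; the real path comes from GetCompileCachePath):

MindIRLoader loader;
loader.set_need_renormalize(false);
loader.set_need_layout(true);  // opt in to layout parsing
FuncGraphPtr fg = loader.LoadMindIR("graph_cache_0.mindir");  // illustrative path
LayoutMap layouts = loader.get_layout_map();  // empty unless need_layout was set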

View File

@ -25,6 +25,36 @@
#include "proto/mind_ir.pb.h"
namespace mindspore {
class Layout {
public:
Layout() = default;
const std::vector<int64_t> &get_device_arrangement() const { return device_arrangement_; }
void set_device_arrangement(const std::vector<int64_t> &device_arrangement) {
device_arrangement_ = device_arrangement;
}
const std::vector<int64_t> &get_tensor_map() const { return tensor_map_; }
void set_tensor_map(const std::vector<int64_t> &tensor_map) { tensor_map_ = tensor_map; }
const std::vector<int64_t> &get_slice_shape() const { return slice_shape_; }
void set_slice_shape(const std::vector<int64_t> &slice_shape) { slice_shape_ = slice_shape; }
int64_t get_field_size() const { return field_size_; }
void set_field_size(int64_t field_size) { field_size_ = field_size; }
bool get_uniform_split() const { return uniform_split_; }
void set_uniform_split(bool uniform_split) { uniform_split_ = uniform_split; }
const std::string &get_opt_shard_group() const { return opt_shard_group_; }
void set_opt_shard_group(const std::string &opt_shard_group) { opt_shard_group_ = opt_shard_group; }
private:
std::vector<int64_t> device_arrangement_{};
std::vector<int64_t> tensor_map_{};
std::vector<int64_t> slice_shape_{};
int64_t field_size_ = 0;
bool uniform_split_ = false;
std::string opt_shard_group_ = "";
};
using LayoutPtr = std::shared_ptr<Layout>;
using LayoutMap = std::map<string, LayoutPtr>;
class MindIRLoader {
public:
MindIRLoader() = default;
@ -35,9 +65,11 @@ class MindIRLoader {
bool get_need_renormalize() const { return need_renormalize_; }
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
void set_need_layout(bool need_layout) { need_layout_ = need_layout; }
void set_weights_value_map(const std::map<string, ValuePtr> &weights_value_map) {
weights_value_map_ = weights_value_map;
}
LayoutMap get_layout_map() { return layout_map_; }
FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
FuncGraphPtr LoadMindIR(const std::string &file_name);
std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
@ -52,6 +84,8 @@ class MindIRLoader {
bool inc_load_ = false;
bool need_renormalize_ = true;
std::map<string, ValuePtr> weights_value_map_;
bool need_layout_ = false;
LayoutMap layout_map_;
};
std::string LoadPreprocess(const std::string &file_name);

View File

@ -76,7 +76,8 @@ message ModelProto {
optional GraphProto graph = 7;
repeated GraphProto functions = 8; // all the graphs without the main graph.
optional string preprocessor = 9; // data graph from MindData.
optional bool little_endian = 10; // byte order on the load device
optional bool little_endian = 10; // byte order on the load device.
optional ParallelProto parallel = 11; // layout information for parallel.
}
@ -136,4 +137,16 @@ message TensorProto {
optional string ref_key = 13;
}
message ParallelProto {
repeated LayoutProto layout = 1;
}
message LayoutProto {
optional string name = 1;
repeated int64 device_arrangement_int = 2;
repeated int64 tensor_map_int = 3;
repeated int64 slice_shape_int = 4;
optional int64 field_size = 5;
optional bool uniform_split = 6;
optional string opt_shard_group = 7;
}
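For reference, one record as BuildLayout would emit it, built by hand with the generated C++ API (all values are illustrative):

mind_ir::LayoutProto layout;
layout.set_name("conv1.weight");       // illustrative parameter name
layout.add_device_arrangement_int(4);  // device matrix: 4 devices
layout.add_tensor_map_int(1);          // dim 0 sharded along device dim 1
layout.add_tensor_map_int(-1);         // dim 1 replicated (illustrative convention)
layout.add_slice_shape_int(16);
layout.add_slice_shape_int(3);         // local slice shape after sharding
layout.set_field_size(0);
layout.set_uniform_split(true);
layout.set_opt_shard_group("");        // no optimizer shard group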

View File

@ -24,9 +24,11 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return "";
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
return "";
}
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) {
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &layout_fg) {
ModelProtoPtr empty_model;
return empty_model;
}