!17899 Fix compile cache bug for resnet50
Merge pull request !17899 from LiangZhibo/cache
commit 708e56f659
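This change threads the dataset queue name from Python's `_Executor` through `ExecutorPy::Compile` into the compile cache loader, so a graph restored from the frontend compile cache is re-bound to the current dataset queue. It also saves parameter tensor data into the cached MindIR model (`save_tensor_data`), restricts saving and loading of the cache to training phases, and fixes MindIR export/parsing of `Bool` shapes, `Int`/`Float` scalar type attributes, and `Depend` nodes so the cached graph round-trips.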
@@ -28,7 +28,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
 
 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
 
-mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph);
+mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
 
 void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
 } // namespace mindspore
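The widened `GetBinaryProto` signature stays source-compatible with existing callers through the default argument; only the compile-cache writer opts in to saving tensor data, as the later hunks show. A minimal standalone sketch of the pattern (the `ModelProto` struct and `main` here are illustrative stand-ins, not MindSpore code):

```cpp
#include <iostream>
#include <string>

// Stand-in for mind_ir::ModelProto, for illustration only.
struct ModelProto {
  bool has_tensor_data = false;
};

// The default argument keeps old call sites compiling unchanged.
ModelProto GetBinaryProto(const std::string &graph_name, bool save_tensor_data = false) {
  ModelProto model;
  model.has_tensor_data = save_tensor_data;  // raw tensor bytes only on request
  std::cout << "export " << graph_name << ", tensors=" << save_tensor_data << "\n";
  return model;
}

int main() {
  GetBinaryProto("bprop_fg");        // bprop cache path: structure only
  GetBinaryProto("train_fg", true);  // compile cache path: keep parameter data
  return 0;
}
```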
@@ -120,7 +120,7 @@ void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_grap
     return;
   }
   std::ofstream fout(bprop_mindir_realpath.value());
-  mind_ir::ModelProto fg_model = GetBinaryProto(func_graph);
+  mind_ir::ModelProto fg_model = GetBinaryProto(func_graph, false);
   if (!fg_model.SerializeToOstream(&fout)) {
     MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
                     << bprop_mindir_realpath.value() << "\".";
@@ -75,7 +75,7 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
          py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
     .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
-         py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
+         py::arg("use_vm") = py::bool_(false), py::arg("queue_name"), "Compile obj by executor.")
     .def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
          py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
     .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
@@ -178,7 +178,7 @@ void SetGpuLoopSink(const ResourcePtr &resource) {
   }
 }
 
-void GetCachedFuncGraph(const ResourcePtr &resource) {
+void GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_name) {
   MS_EXCEPTION_IF_NULL(resource);
   auto realpath = Common::GetRealPath(kCompileCacheFilePath);
   if (!realpath.has_value()) {
@@ -204,6 +204,14 @@ void GetCachedFuncGraph(const ResourcePtr &resource) {
     res_mng->AddFuncGraph(fg);
     fg->set_manager(res_mng);
   }
+  auto cnodes = fg->GetOrderedCnodes();
+  for (auto cnode : cnodes) {
+    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    if (prim != nullptr && prim->HasAttr("shared_name")) {
+      prim->set_attr("shared_name", MakeValue(queue_name));
+      break;
+    }
+  }
   resource->set_func_graph(fg);
 }
 
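The added loop is the core of the resnet50 cache fix: a graph restored from the compile cache still carries the `shared_name` of the dataset queue from the run that produced it, so the loader patches the first primitive holding that attribute with the queue name of the current run. A standalone model of the walk-and-patch step (`Node` is a stand-in for MindSpore's CNode/Primitive machinery):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in node: just a bag of string attributes.
struct Node {
  std::map<std::string, std::string> attrs;
};

// Patch the first node carrying "shared_name" so the cached graph
// reads from the live dataset queue instead of the stale one.
void RebindQueue(std::vector<Node> *ordered_nodes, const std::string &queue_name) {
  for (auto &node : *ordered_nodes) {
    auto it = node.attrs.find("shared_name");
    if (it != node.attrs.end()) {
      it->second = queue_name;
      break;  // only the first match is patched, as in the diff
    }
  }
}

int main() {
  std::vector<Node> nodes{{{{"shared_name", "stale_queue_from_last_run"}}}};
  RebindQueue(&nodes, "queue_0");
  std::cout << nodes[0].attrs["shared_name"] << "\n";  // prints "queue_0"
  return 0;
}
```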
@@ -220,7 +228,7 @@ void CacheFuncGraph(const ResourcePtr &resource) {
     MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!";
   }
   FuncGraphPtr fg = resource->func_graph();
-  mind_ir::ModelProto fg_model = GetBinaryProto(fg);
+  mind_ir::ModelProto fg_model = GetBinaryProto(fg, true);
   if (!fg_model.SerializeToOstream(&fout)) {
     MS_LOG(EXCEPTION) << "Failed to cache the graph to file " << realpath.value();
   }
@@ -599,6 +607,11 @@ bool IsPhaseExportAir(const std::string &phase_s) {
   return phase_s.rfind(phase_to_export) != std::string::npos;
 }
 
+bool IsPhaseTrain(const std::string &phase_s) {
+  const std::string phase_to_train = "train";
+  return phase_s.rfind(phase_to_train) != std::string::npos;
+}
+
 std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
   bool is_air = IsPhaseExportAir(phase_s);
 
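`IsPhaseTrain` mirrors the existing `IsPhaseExportAir` helper: `rfind` returning anything other than `npos` is a plain substring test, so any phase string containing "train" selects the cache path. Extracted below as a runnable check (the example phase strings are assumed for illustration):

```cpp
#include <iostream>
#include <string>

// Copied from the hunk above; rfind(s) != npos is a substring test.
bool IsPhaseTrain(const std::string &phase_s) {
  const std::string phase_to_train = "train";
  return phase_s.rfind(phase_to_train) != std::string::npos;
}

int main() {
  std::cout << IsPhaseTrain("train.1623.Net") << "\n";  // 1: compile cache applies
  std::cout << IsPhaseTrain("eval.1623.Net") << "\n";   // 0: compile cache skipped
  return 0;
}
```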
@@ -627,7 +640,8 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
   resource->results()[kBackend] = backend_ptr;
   // If the 'use_frontend_compile_cache' context has been set true and the cache is read successfully,
   // do the backend actions only.
-  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE) && resource->func_graph() != nullptr) {
+  if (IsPhaseTrain(phase_s) && MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE) &&
+      resource->func_graph() != nullptr) {
     return BackendPipeline();
   }
   return VmPipeline();
@@ -635,7 +649,8 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
   return GePipeline();
 }
 
-bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) {
+bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm,
+                              const std::string &queue_name) {
   MS_LOG(DEBUG) << "Start ExecutorPy compile!";
   if ((!py::isinstance<py::str>(phase))) {
     MS_LOG(ERROR) << "Arg phase must be string.";
@@ -658,7 +673,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
   ResourcePtr resource = std::make_shared<Resource>(obj);
 
   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE)) {
-    GetCachedFuncGraph(resource);
+    GetCachedFuncGraph(resource, queue_name);
   }
 
   auto p_actions = GetPipeline(resource, phase_s, use_vm);
@@ -680,7 +695,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
   executor_info->arg_list_size = size;
   executor_info->resource = resource;
   info_[phase_s] = executor_info;
-  pip->Run();
+  pip->Run(phase_s);
 
   // save the run graph func to MsPipeLine
   SaveCompiledGraph(phase_s);
@@ -724,11 +739,12 @@ static std::string PrintArgs(const py::tuple &args) {
   return "";
 }
 
-bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) {
+bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm,
+                         const std::string &queue_name) {
   bool ret_value = false;
   try {
     MS_LOG(DEBUG) << PrintArgs(args);
-    ret_value = CompileInner(obj, args, phase, use_vm);
+    ret_value = CompileInner(obj, args, phase, use_vm, queue_name);
   } catch (const py::error_already_set &ex) {
     if (!StaticAnalysisException::Instance().HasException()) {
       // print function call stack info before release
@@ -771,12 +787,12 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
   return ret_value;
 }
 
-void Pipeline::Run() {
+void Pipeline::Run(const std::string &phase_s) {
   MS_LOG(INFO) << "Pipeline run";
   MS_EXCEPTION_IF_NULL(resource_);
   FuncGraphPtr user_graph = nullptr;
 
-  WITH(MsProfile::GetProfile())[&user_graph, this]() {
+  WITH(MsProfile::GetProfile())[&user_graph, &phase_s, this]() {
     size_t i = 0;
     for (auto &action : actions_) {
 #ifdef ENABLE_TIMELINE
@@ -792,7 +808,7 @@ void Pipeline::Run() {
       if (action.first == "task_emit") {
         SetGpuLoopSink(resource_);
       } else if (action.first == "validate") {
-        if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_COMPILE_CACHE)) {
+        if (IsPhaseTrain(phase_s) && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_COMPILE_CACHE)) {
           CacheFuncGraph(resource_);
         }
       }
@@ -49,7 +49,7 @@ class Pipeline {
 
   ~Pipeline() = default;
 
-  void Run();
+  void Run(const std::string &phase_s);
 
   ResourcePtr resource() { return resource_; }
 
@@ -72,8 +72,10 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
   ~ExecutorPy();
 
   void SaveCompiledGraph(const std::string &phase_s);
-  bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
-  bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm);
+  bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm,
+                    const std::string &queue_name);
+  bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm,
+               const std::string &queue_name);
 
   void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list);
 
@@ -71,7 +71,7 @@ class IrExporter {
   explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
   virtual ~IrExporter() = default;
   std::string GetDumpString(const FuncGraphPtr &func_graph);
-  mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph);
+  mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
 
  private:
  IrExportBuilderPtr builder_;
@@ -83,12 +83,14 @@ class IrExportBuilder {
   ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
   std::string GetProtoString(const FuncGraphPtr &func_graph);
   void BuildModelInfo();
-  void BuildModel(const FuncGraphPtr &func_graph);
+  void BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
   mind_ir::ModelProto Model() { return model_; }
 
  private:
-  void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
-  void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
+  void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+                      bool save_tensor_data = false);
+  void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+                       bool save_tensor_data = false);
   void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
   void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
   void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
@@ -148,7 +150,7 @@ std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
   return builder_->GetProtoString(func_graph);
 }
 
-mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
+mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
   if ((builder_ == nullptr) || (func_graph == nullptr)) {
     MS_LOG(EXCEPTION) << "Input params is null.";
   }
@@ -157,7 +159,7 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
   builder_->BuildModelInfo();
 
   // Export model and return string
-  builder_->BuildModel(func_graph);
+  builder_->BuildModel(func_graph, save_tensor_data);
 
   return builder_->Model();
 }
@@ -173,7 +175,7 @@ void IrExportBuilder::BuildModelInfo() {
   model_.set_model_version("1.1.0");
 }
 
-void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
+void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
   mind_ir::GraphProto *graph_proto = model_.mutable_graph();
   graph_proto->set_name(func_graph->ToString());
   graph_proto->set_bprop_hash(func_graph->bprop_hash());
@@ -183,21 +185,23 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
   while (!todo_.empty()) {
     FuncGraphPtr fg = todo_.back();
     todo_.pop_back();
-    BuildFuncGraph(fg, graph_proto);
+    BuildFuncGraph(fg, graph_proto, save_tensor_data);
   }
 }
 
-void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
+void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+                                     bool save_tensor_data) {
   // Export parameters
   // 1. parameters should be mapped to ValueInfoProto
   // 2. parameters with default value should be mapped to Initializer
-  BuildParameters(func_graph, graph_proto);
+  BuildParameters(func_graph, graph_proto, save_tensor_data);
 
   // Export operator nodes(include output)
   BuildNodes(func_graph, graph_proto);
 }
 
-void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
+void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+                                      bool save_tensor_data) {
   for (auto &item : func_graph->parameters()) {
     auto param = item->cast<ParameterPtr>();
     if (param == nullptr) {
@@ -205,10 +209,14 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
     }
     std::string param_name = GetUniqueNodeName(param);
     if (param->has_default()) {
-      MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
+      MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has default.";
       mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
       parameter_proto->set_name(param_name);
       SetParamToTensorProto(param, parameter_proto);
+      auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
+      if (tensor && save_tensor_data) {
+        parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
+      }
     } else {
       mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
       input_proto->set_name(param_name);
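With `save_tensor_data` set, a parameter that has a default value also gets its raw bytes written into the `TensorProto`, so reloading the cached model can restore `Parameter` values rather than just their shapes and dtypes; the bprop export path passes `false` because it only needs graph structure. A standalone sketch of the branch (the types below are stand-ins for `mind_ir::TensorProto` and `tensor::Tensor`):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for mind_ir::TensorProto's raw_data field.
struct TensorProto {
  std::vector<char> raw_data;
};

// Copy the parameter's buffer into the proto only when requested.
void MaybeSaveTensor(TensorProto *proto, const void *data, size_t nbytes, bool save_tensor_data) {
  if (data != nullptr && save_tensor_data) {
    const char *bytes = static_cast<const char *>(data);
    proto->raw_data.assign(bytes, bytes + nbytes);
  }
}

int main() {
  float weights[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  TensorProto proto;
  MaybeSaveTensor(&proto, weights, sizeof(weights), true);  // compile cache: keep values
  std::cout << proto.raw_data.size() << " bytes saved\n";   // prints "16 bytes saved"
  return 0;
}
```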
@@ -388,12 +396,16 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
     tensor_proto->set_name(shape_name);
     SetTensorProto(type, shape, tensor_proto);
   } else if (type->isa<Number>()) {
-    string shape_name = "shape" + std::to_string(GetTupleIndex());
-    *seq_string += shape_name + ",";
-    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
-    tensor_proto->set_name(shape_name);
-    tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
-    tensor_proto->add_dims(1);
+    if (type->isa<Bool>()) {
+      attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
+    } else {
+      string shape_name = "shape" + std::to_string(GetTupleIndex());
+      *seq_string += shape_name + ",";
+      mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
+      tensor_proto->set_name(shape_name);
+      tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
+      tensor_proto->add_dims(1);
+    }
   } else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
     *seq_string += type->type_name() + ",";
   } else {
@@ -644,7 +656,17 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
 }
 
 void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
-  if (value->isa<StringImm>()) {
+  if (value->isa<Int>()) {
+    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
+    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
+    auto int_value = value->cast<IntPtr>();
+    tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
+  } else if (value->isa<Float>()) {
+    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
+    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
+    auto float_value = value->cast<FloatPtr>();
+    tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
+  } else if (value->isa<StringImm>()) {
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
     attr_proto->add_strings(GetValue<std::string>(value));
   } else if (value->isa<BoolImm>()) {
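Note the symmetry with the parser change further down: `Int` and `Float` *type* values (as opposed to scalars carrying data) are encoded as `TENSORS` attributes holding only a data type derived from the bit width, and the new `AttributeProto_AttributeType_TENSORS` case in `MSANFModelParser::ParseAttrInScalarForm` maps that data type back to a `TypePtr` through `kDefaultValueSwitchMap`. Both halves are needed for primitive attributes containing type objects to survive the cache round trip.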
@@ -751,8 +773,9 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
   return exporter->GetDumpString(func_graph);
 }
 
-mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph) {
+mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
   auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
-  return exporter->GetDumpProto(func_graph);
+  auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
+  return result;
 }
 } // namespace mindspore
@@ -182,9 +182,9 @@ class _MindSporeFunction:
         if key not in ms_compile_cache.keys():
             is_compile = False
             if self.obj is None:
-                is_compile = self._executor.compile(self.fn, args_list, phase, True)
+                is_compile = self._executor.compile(self.fn, args_list, phase, True, "")
             else:
-                is_compile = self._executor.compile(self.obj, args_list, phase, True)
+                is_compile = self._executor.compile(self.obj, args_list, phase, True, "")
             if not is_compile:
                 raise RuntimeError("Executor compile failed.")
             if context.get_context("enable_ge"):
@@ -445,6 +445,7 @@ class _Executor:
         self._executor = Executor_.get_instance()
         self.compile_cache = {}
         self._executor.set_py_exe_path(sys.executable)
+        self.queue_name = ""
 
     def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                      input_indexs, phase='dataset'):
@@ -471,6 +472,7 @@ class _Executor:
                                 input_indexs=input_indexs,
                                 phase=phase):
             raise RuntimeError("Failure to init and dataset subgraph!")
+        self.queue_name = queue_name
         return True
 
     def _build_data_graph(self, obj, phase):
@@ -526,7 +528,7 @@ class _Executor:
         enable_debug_runtime = context.get_context("enable_debug_runtime")
         enable_ge = context.get_context("enable_ge")
         use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
-        result = self._executor.compile(obj, args_list, phase, use_vm)
+        result = self._executor.compile(obj, args_list, phase, use_vm, self.queue_name)
         self.compile_cache[phase] = phase
         if not result:
             raise RuntimeError("Executor compile failed.")
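On the Python side, `_Executor` now remembers the queue name registered by `init_dataset` and forwards it on every `compile` call, which is how the C++ `GetCachedFuncGraph` above learns the live queue. Call sites with no bound dataset (`ms_function` compilation and the `get_func_graph` test helper below) pass an empty string.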
@@ -392,6 +392,14 @@ ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &
     case mind_ir::AttributeProto_AttributeType_BOOL: {
       return ParseAttrInScalar_int32_t_bool(attr_proto, index);
     }
+    case mind_ir::AttributeProto_AttributeType_TENSORS: {
+      const int attr_tensor_type = attr_proto.tensors(index).data_type();
+      if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
+        MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
+        return {};
+      }
+      return TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
+    }
     default:
       MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
       return {};
@@ -422,6 +430,11 @@ void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto
     name = "value" + std::to_string(i + 1);
     multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
   }
+  for (int i = 0; i < attr_proto.tensors_size(); i++) {
+    auto res = ParseAttrInScalarForm(attr_proto, i);
+    name = "value" + std::to_string(i + 1);
+    multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
+  }
 }
 
 ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) {
@@ -810,7 +823,8 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
     inputs.push_back(anfnode_build_map_[input_name]);
   }
   prim->set_attr("is_load", MakeValue(true));
-  auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
+  CNodePtr cnode_ptr;
+  cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
   MS_EXCEPTION_IF_NULL(cnode_ptr);
 
   if (kv.size() == 0) {
@@ -818,6 +832,9 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
     const ValuePtr kUMonad = std::make_shared<UMonad>();
     auto monad_abs = kUMonad->ToAbstract();
     cnode_ptr->set_abstract(monad_abs);
+  } else if (node_type == "Depend") {
+    const ValuePtr kBool = std::make_shared<BoolImm>(true);
+    cnode_ptr->set_abstract(kBool->ToAbstract());
   } else {
     AbstractBasePtrList elem;
     for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
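The `Depend` branch gives loaded `Depend` nodes a placeholder `BoolImm` abstract instead of building an abstract tuple from their inputs, presumably so that the control-dependency edge does not confuse shape and type inference when the cached graph is re-validated.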
@@ -92,3 +92,4 @@ def test_lenet():
     label1 = Tensor(np.ones([32]).astype(np.int32))
     net1 = LeNet()
     train(net1, data1, label1)
+    context.set_context(save_compile_cache=False, load_compile_cache=False)
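The test now switches `save_compile_cache` and `load_compile_cache` back off once the cached run has been exercised, so later tests in the same process do not accidentally save or load a compile cache.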
@@ -26,7 +26,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
 
 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
 
-mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph) {
+mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
   mind_ir::ModelProto empty_model;
   return empty_model;
 }
@@ -38,7 +38,7 @@ def get_func_graph(obj, *args, phase="validate"):
     else:
         phase = phase_prefix + phase + '.' + str(obj.create_time)
     _executor = Executor_.get_instance()
-    _executor.compile(obj, args_list, phase, False)
+    _executor.compile(obj, args_list, phase, False, "")
     return _executor.get_func_graph(phase)
 
 def test_softmax_relu():