!12782 support load mindir and run

From: @wangnan39
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-08 12:38:44 +08:00 committed by Gitee
commit d8ec8f6e95
20 changed files with 351 additions and 55 deletions

View File

@ -66,6 +66,10 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
for (auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract();
// Keep previous inferred value for CNode if is loaded from MindIR.
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
continue;
}
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
node->set_abstract(nullptr);
@ -113,6 +117,69 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
return ret;
}
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
FuncGraphPtr loaded_graph = nullptr;
size_t loaded_graph_num = 0;
auto all_graphs = manager->func_graphs();
for (auto &graph : all_graphs) {
MS_EXCEPTION_IF_NULL(graph);
if (graph->has_attr("is_load")) {
loaded_graph = graph;
loaded_graph_num += 1;
}
}
if (loaded_graph_num == 0) {
return nullptr;
}
if (loaded_graph_num == 1) {
return loaded_graph;
}
MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
}
void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
MS_EXCEPTION_IF_NULL(res);
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
FuncGraphPtr root_graph = *(manager->roots().begin());
auto root_inputs = root_graph->get_inputs();
auto loaded_inputs = loaded_graph->get_inputs();
size_t root_inputs_num = root_inputs.size();
size_t loaded_inputs_num = loaded_inputs.size();
if (root_inputs_num != loaded_inputs_num) {
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
<< loaded_inputs_num;
}
for (size_t index = 0; index < root_inputs_num; index++) {
auto root_input = root_inputs[index];
auto loaded_input = loaded_inputs[index];
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
MS_EXCEPTION_IF_NULL(root_shape);
MS_EXCEPTION_IF_NULL(loaded_shape);
MS_EXCEPTION_IF_NULL(root_type);
MS_EXCEPTION_IF_NULL(loaded_type);
if (root_shape->shape() != loaded_shape->shape()) {
MS_EXCEPTION(ValueError) << "The " << index
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
<< ", input shape of loaded graph: " << loaded_shape->ToString();
}
if (root_type->type_id() != loaded_type->type_id()) {
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
<< ", input type of loaded graph: " << loaded_type->ToString();
}
}
}
bool ParseAction(const ResourcePtr &res) {
if (!res->input()) {
MS_LOG(EXCEPTION) << "Parse error";
@ -255,12 +322,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
}
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec = res->args_spec();
auto context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
context->ParallelParameterContextInitShape(func_graph);
// get original loaded graph to check inputs later
auto loaded_graph_ptr = GetLoadedGraph(res);
// suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
for (const auto &param : func_graph->parameters()) {
@ -294,7 +363,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
}
}
// check input after abstract when there is a loaded graph
if (loaded_graph_ptr != nullptr) {
CheckRootInputShapeAndType(res, loaded_graph_ptr);
}
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
return true;
}

View File

@ -111,6 +111,7 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load model as Graph.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

View File

@ -203,6 +203,19 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
return true;
}
bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting FuncGraph object";
auto func_graph = obj.cast<FuncGraphPtr>();
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
return false;
}
auto new_fg = BasicClone(func_graph);
new_fg->set_attr("is_load", MakeValue(true));
*data = new_fg;
return true;
}
bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting slice object";
@ -368,47 +381,21 @@ bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype
}
} // namespace
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
// check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
return false;
}
bool ret = true;
bool ConvertSingleData(const py::object &obj, ValuePtr *const data) {
MS_EXCEPTION_IF_NULL(data);
ValuePtr converted = nullptr;
if (py::isinstance<py::none>(obj)) {
converted = kNone;
} else if (py::isinstance<py::bool_>(obj)) {
converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
} else if (py::isinstance<py::int_>(obj)) {
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
} else if (py::isinstance<py::float_>(obj)) {
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
} else if (py::isinstance<py::str>(obj)) {
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
} else if (py::isinstance<py::dict>(obj)) {
ret = ConvertDict(obj, &converted, use_signature);
} else if (py::isinstance<py::slice>(obj)) {
ret = ConvertSlice(obj, &converted);
} else if (py::isinstance<py::ellipsis>(obj)) {
converted = kEllipsis;
} else if (py::isinstance<py::tuple>(obj)) {
ret = ConvertTuple(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
ret = ConvertCellList(obj, &converted, use_signature);
} else if (py::isinstance<Cell>(obj)) {
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
} else if (py::isinstance<py::list>(obj)) {
ret = ConvertList(obj, &converted, use_signature);
} else if (py::isinstance<py::module>(obj)) {
ConvertNameSpace(obj, &converted);
} else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
ConvertDataClass(obj, &converted);
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
ret = ConvertPrimitive(obj, &converted, use_signature);
} else if (py::isinstance<MetaFuncGraph>(obj)) {
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
} else if (py::isinstance<Type>(obj)) {
converted = obj.cast<TypePtr>();
} else if (py::isinstance<Tensor>(obj)) {
@ -425,9 +412,50 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
} else {
ret = ConvertOtherObj(obj, &converted);
return false;
}
*data = converted;
return true;
}
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
// check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
return false;
}
ValuePtr converted = nullptr;
bool ret = ConvertSingleData(obj, &converted);
if (ret) {
*data = converted;
return true;
}
if (py::isinstance<py::int_>(obj)) {
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
} else if (py::isinstance<py::float_>(obj)) {
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
} else if (py::isinstance<py::dict>(obj)) {
ret = ConvertDict(obj, &converted, use_signature);
} else if (py::isinstance<py::slice>(obj)) {
ret = ConvertSlice(obj, &converted);
} else if (py::isinstance<py::tuple>(obj)) {
ret = ConvertTuple(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
ret = ConvertCellList(obj, &converted, use_signature);
} else if (py::isinstance<Cell>(obj)) {
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
} else if (py::isinstance<py::list>(obj)) {
ret = ConvertList(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
ret = ConvertPrimitive(obj, &converted, use_signature);
} else if (py::isinstance<MetaFuncGraph>(obj)) {
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
} else if (py::isinstance<FuncGraph>(obj)) {
ret = ConvertFuncGraph(obj, &converted);
} else {
ret = ConvertOtherObj(obj, &converted);
}
*data = converted;
return ret;
}

View File

@ -113,6 +113,49 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
return para_node;
}
void BroadenCNodeAbstract(const FuncGraphPtr &func_graph) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto abstract = node->abstract();
if (abstract != nullptr) {
node->set_abstract(abstract->Broaden());
}
}
}
void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
if (!value->isa<FuncGraph>()) {
return;
}
auto resolved_graph = value->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(resolved_graph);
if (!resolved_graph->has_attr("is_load")) {
return;
}
auto top_graph = Parser::GetTopFuncGraph();
std::vector<AnfNodePtr> input_params;
for (auto const &param : resolved_graph->parameters()) {
auto param_ptr = dyn_cast<Parameter>(param);
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->has_default()) {
param_ptr->set_func_graph(top_graph);
func_graph->add_used_global_parameters(param_ptr);
// update top_graph
top_graph->add_parameter(param_ptr);
size_t hyper_param_count = top_graph->hyper_param_count();
top_graph->set_hyper_param_count(hyper_param_count + 1);
} else {
input_params.push_back(param_ptr);
}
}
resolved_graph->set_parameters(input_params);
BroadenCNodeAbstract(resolved_graph);
}
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
AnfNodePtr output = nullptr;
if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
@ -146,6 +189,7 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
return false;
}
MS_EXCEPTION_IF_NULL(convert_result);
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);

View File

@ -48,6 +48,7 @@
#include "pybind_api/pybind_patch.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/constants.h"
#include "ps/util.h"
@ -1096,6 +1097,8 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
#endif
}
FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::LoadMindIR(file_name); }
void ReleaseGeTsd() {
auto context_ptr = MsContext::GetInstance();
if (context_ptr != nullptr) {

View File

@ -140,6 +140,7 @@ void ClearResAtexit();
void ReleaseGeTsd();
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
FuncGraphPtr LoadMindIR(const std::string &file_name);
// init and exec dataset sub graph
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,

View File

@ -51,6 +51,9 @@ void ValidateOperation(const AnfNodePtr &node) {
if (abstract::IsInWhiteList(prim)) {
return;
}
if (prim->HasAttr("is_load")) {
return;
}
if (prim->HasPyEvaluator()) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
return;

View File

@ -273,6 +273,9 @@ class CNode : public AnfNode, public EffectInfoHolder {
void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
bool in_forward_flag() const { return in_forward_flag_; }
void set_load_flag(bool is_load) { is_load_ = is_load; }
bool get_load_flag() { return is_load_; }
VarPtr func_graph_as_var() const { return func_graph_as_var_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
@ -304,6 +307,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
bool stop_gradient_;
bool in_forward_flag_ = false;
bool effect_handled_ = false;
bool is_load_ = false;
// inputs_value_ store cnode input value and id in pynative mode
// output_value_ store cnode value and id in pynative mode
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;

View File

@ -68,6 +68,18 @@ AnfNodePtr FuncGraph::output() const {
}
}
const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
std::vector<AnfNodePtr> input_params;
for (auto const &node : parameters_) {
MS_EXCEPTION_IF_NULL(node);
auto parameter = dyn_cast<Parameter>(node);
if (!parameter->has_default()) {
input_params.push_back(parameter);
}
}
return input_params;
}
ParameterPtr FuncGraph::add_parameter() {
FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_func_graph);

View File

@ -160,6 +160,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
abstract::AbstractFunctionPtr abstract();
abstract::AbstractBasePtr ToAbstract() override;
// get function graph inputs, but parameters
const std::vector<AnfNodePtr> get_inputs() const;
// Return the graph's output, or nullptr if not yet deduced.
AnfNodePtr output() const;
void set_output(const AnfNodePtr &value, bool force_new_ret = false);

View File

@ -91,6 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_forward(old_node->forward().first, old_node->forward().second);
new_node->set_inputs_value(old_node->inputs_value());
new_node->set_attrs(old_node->attrs());
new_node->set_load_flag(old_node->get_load_flag());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope);
new_node->CloneUserData(old_node);

View File

@ -228,17 +228,14 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
}
if (!tensor_proto.has_data_type()) {
MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!";
return nullptr;
MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!";
}
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!";
return nullptr;
MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not support yet!";
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
return tensor_info;
}
@ -253,9 +250,14 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
string debug_info_name = ParseParameterName(parameter_proto.name());
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
node->set_debug_info(debug_info_ptr);
node->set_name(parameter_proto.name());
node->set_name(debug_info_name);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
param_info->set_name(debug_info_name);
tensor_info->set_param_info(param_info);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
@ -284,13 +286,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
string debug_info_name = ParseParameterName(value_proto.name());
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
node->set_debug_info(debug_info_ptr);
node->set_name(value_proto.name());
node->set_name(debug_info_name);
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
anfnode_build_map_[value_proto.name()] = node;
@ -300,15 +302,6 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
for (int i = 0; i < importProto.parameter_size(); ++i) {
const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
}
MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
for (int i = 0; i < importProto.input_size(); ++i) {
const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
@ -317,6 +310,15 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
return false;
}
}
MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
for (int i = 0; i < importProto.parameter_size(); ++i) {
const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
}
return true;
}
@ -745,7 +747,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
inputs.push_back(anfnode_build_map_[input_name]);
}
prim->set_attr("is_load", MakeValue(true));
auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
@ -777,6 +779,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
cnode_ptr->set_debug_info(debug_info_ptr);
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
cnode_ptr->set_load_flag(true);
anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
@ -804,6 +807,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
inputs.push_back(maketuple_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
return_node->set_load_flag(true);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
} else {
@ -812,6 +816,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
return_node->set_load_flag(true);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
}

View File

@ -20,7 +20,7 @@ Pre-defined building blocks or computing units to construct neural networks.
from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
from .learning_rate_schedule import *
from .dynamic_lr import *
from .cell import Cell, GraphKernel
from .cell import Cell, GraphKernel, GraphCell
from .layer import *
from .loss import *
from .optim import *
@ -29,7 +29,7 @@ from .wrap import *
from .sparse import *
__all__ = ["Cell", "GraphKernel"]
__all__ = ["Cell", "GraphKernel", "GraphCell"]
__all__.extend(layer.__all__)
__all__.extend(loss.__all__)
__all__.extend(optim.__all__)

View File

@ -25,7 +25,7 @@ from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.context import ParallelMode
from .. import context
from .._c_expression import init_pipeline, Cell_
from .._c_expression import init_pipeline, Cell_, FuncGraph
from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
@ -1191,3 +1191,39 @@ class GraphKernel(Cell):
def construct(self):
raise NotImplementedError
class GraphCell(Cell):
"""
Base class for running the graph loaded from a MindIR.
This feature is still under development. Currently `GraphCell` do not support modifying the structure of the
diagram, and can only use data that shape and type are the same as the input when exporting the MindIR.
Args:
graph (object): A compiled graph loaded from MindIR.
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.train import export, load
>>>
>>> net = nn.Conv2d(1, 1, kernel_size=3)
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> export(net, input, file_name="net", file_format="MINDIR")
>>> graph = load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
"""
def __init__(self, graph):
super(GraphCell, self).__init__(auto_prefix=True)
if not isinstance(graph, FuncGraph):
raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
self.graph = graph
def construct(self, *inputs):
return self.graph(*inputs)
def __call__(self, *inputs):
return self.compile_and_run(*inputs)

View File

@ -22,10 +22,10 @@ from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp
from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
"load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
"load_distributed_checkpoint"]

View File

@ -141,10 +141,22 @@ class Model:
self._global_rank = _get_global_rank()
self._parameter_broadcast = _get_parameter_broadcast()
self._check_for_graph_cell(kwargs)
self._train_network = self._build_train_network()
self._build_eval_network(metrics, eval_network, eval_indexes)
self._build_predict_network()
def _check_for_graph_cell(self, kwargs):
if not isinstance(self._network, nn.GraphCell):
return
if self._amp_level != "O0":
logger.warning("amp_level will not work when network is a GraphCell.")
if self._loss_fn is not None or self._optimizer is not None:
raise ValueError("Currently loss_fn and optimizer should be None when network is a GraphCell. ")
if kwargs:
raise ValueError("Currently kwargs should be empty when network is a GraphCell. ")
def _process_amp_args(self, kwargs):
if self._amp_level in ["O0", "O3"]:
self._keep_bn_fp32 = False
@ -594,6 +606,8 @@ class Model:
>>> model.train(2, dataset)
"""
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
Validator.check_is_int(sink_size)
dataset_size = train_dataset.get_dataset_size()
if dataset_size == 0:
@ -718,9 +732,12 @@ class Model:
>>> acc = model.eval(dataset, dataset_sink_mode=False)
"""
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode is True:
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network

View File

@ -38,6 +38,7 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from .._c_expression import load_mindir
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@ -229,6 +230,49 @@ def _check_param_prefix(filter_prefix, param_name):
return False
def load(file_name):
"""
Load MindIR.
The returned object can be executed by a `GraphCell`. However, there are some limitations to the current use
of `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.
Args:
file_name (str): MindIR file name.
Returns:
Object, a compiled graph that can executed by `GraphCell`.
Raises:
ValueError: MindIR file is incorrect.
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.train import export, load
>>>
>>> net = nn.Conv2d(1, 1, kernel_size=3)
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> export(net, input, file_name="net", file_format="MINDIR")
>>> graph = load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
"""
if not isinstance(file_name, str):
raise ValueError("The file name must be string.")
if not os.path.exists(file_name):
raise ValueError("The file is not exist.")
if not file_name.endswith(".mindir"):
raise ValueError("The MindIR should end with mindir, please input the correct file name.")
logger.info("Execute the process of loading mindir.")
graph = load_mindir(file_name)
if graph is None:
raise RuntimeError("Load MindIR failed.")
return graph
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
"""
Loads checkpoint info from a specified file.

View File

@ -22,7 +22,7 @@ from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export
from mindspore.train.serialization import export, load
def weight_variable():
@ -112,3 +112,26 @@ def test_export_lenet_grad_mindir():
export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
verify_name = "lenet_grad.mindir"
assert os.path.exists(verify_name)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_load_mindir_and_run():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
network = LeNet5()
network.set_train()
inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
outputs0 = network(inputs0)
inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
mindir_name = "test_lenet_load.mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(inputs0)
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())