forked from mindspore-Ecosystem/mindspore
!12782 support load mindir and run
From: @wangnan39 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
d8ec8f6e95
|
@ -66,6 +66,10 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
|
|||
for (auto &node : manager->all_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const AbstractBasePtr &prev_inferred = node->abstract();
|
||||
// Keep previous inferred value for CNode if is loaded from MindIR.
|
||||
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
|
||||
continue;
|
||||
}
|
||||
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
|
||||
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
|
||||
node->set_abstract(nullptr);
|
||||
|
@ -113,6 +117,69 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
|||
return ret;
|
||||
}
|
||||
|
||||
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto manager = res->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
FuncGraphPtr loaded_graph = nullptr;
|
||||
size_t loaded_graph_num = 0;
|
||||
auto all_graphs = manager->func_graphs();
|
||||
for (auto &graph : all_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (graph->has_attr("is_load")) {
|
||||
loaded_graph = graph;
|
||||
loaded_graph_num += 1;
|
||||
}
|
||||
}
|
||||
if (loaded_graph_num == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (loaded_graph_num == 1) {
|
||||
return loaded_graph;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
|
||||
}
|
||||
|
||||
void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto manager = res->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
FuncGraphPtr root_graph = *(manager->roots().begin());
|
||||
auto root_inputs = root_graph->get_inputs();
|
||||
auto loaded_inputs = loaded_graph->get_inputs();
|
||||
|
||||
size_t root_inputs_num = root_inputs.size();
|
||||
size_t loaded_inputs_num = loaded_inputs.size();
|
||||
if (root_inputs_num != loaded_inputs_num) {
|
||||
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
|
||||
<< loaded_inputs_num;
|
||||
}
|
||||
for (size_t index = 0; index < root_inputs_num; index++) {
|
||||
auto root_input = root_inputs[index];
|
||||
auto loaded_input = loaded_inputs[index];
|
||||
|
||||
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
|
||||
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
|
||||
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
|
||||
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
|
||||
MS_EXCEPTION_IF_NULL(root_shape);
|
||||
MS_EXCEPTION_IF_NULL(loaded_shape);
|
||||
MS_EXCEPTION_IF_NULL(root_type);
|
||||
MS_EXCEPTION_IF_NULL(loaded_type);
|
||||
|
||||
if (root_shape->shape() != loaded_shape->shape()) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index
|
||||
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
|
||||
<< ", input shape of loaded graph: " << loaded_shape->ToString();
|
||||
}
|
||||
if (root_type->type_id() != loaded_type->type_id()) {
|
||||
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
|
||||
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
|
||||
<< ", input type of loaded graph: " << loaded_type->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ParseAction(const ResourcePtr &res) {
|
||||
if (!res->input()) {
|
||||
MS_LOG(EXCEPTION) << "Parse error";
|
||||
|
@ -255,12 +322,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
if (res->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
|
||||
}
|
||||
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
abstract::AbstractBasePtrList args_spec = res->args_spec();
|
||||
auto context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
|
||||
context->ParallelParameterContextInitShape(func_graph);
|
||||
|
||||
// get original loaded graph to check inputs later
|
||||
auto loaded_graph_ptr = GetLoadedGraph(res);
|
||||
// suppose that there is not KeywordArgument for the top graph
|
||||
// get the hyper parameter
|
||||
for (const auto ¶m : func_graph->parameters()) {
|
||||
|
@ -294,7 +363,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check input after abstract when there is a loaded graph
|
||||
if (loaded_graph_ptr != nullptr) {
|
||||
CheckRootInputShapeAndType(res, loaded_graph_ptr);
|
||||
}
|
||||
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -111,6 +111,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
|
||||
|
||||
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
|
||||
(py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load model as Graph.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
|
|
@ -203,6 +203,19 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) {
|
||||
MS_LOG(DEBUG) << "Converting FuncGraph object";
|
||||
auto func_graph = obj.cast<FuncGraphPtr>();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
|
||||
return false;
|
||||
}
|
||||
auto new_fg = BasicClone(func_graph);
|
||||
new_fg->set_attr("is_load", MakeValue(true));
|
||||
*data = new_fg;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
|
||||
MS_LOG(DEBUG) << "Converting slice object";
|
||||
|
||||
|
@ -368,47 +381,21 @@ bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
|
||||
// check parameter valid
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data is null pointer";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ret = true;
|
||||
bool ConvertSingleData(const py::object &obj, ValuePtr *const data) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
ValuePtr converted = nullptr;
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
converted = kNone;
|
||||
} else if (py::isinstance<py::bool_>(obj)) {
|
||||
converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
|
||||
} else if (py::isinstance<py::str>(obj)) {
|
||||
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
|
||||
} else if (py::isinstance<py::dict>(obj)) {
|
||||
ret = ConvertDict(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<py::slice>(obj)) {
|
||||
ret = ConvertSlice(obj, &converted);
|
||||
} else if (py::isinstance<py::ellipsis>(obj)) {
|
||||
converted = kEllipsis;
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
ret = ConvertTuple(obj, &converted, use_signature);
|
||||
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
|
||||
ret = ConvertCellList(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<Cell>(obj)) {
|
||||
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
|
||||
} else if (py::isinstance<py::list>(obj)) {
|
||||
ret = ConvertList(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<py::module>(obj)) {
|
||||
ConvertNameSpace(obj, &converted);
|
||||
} else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
|
||||
ConvertDataClass(obj, &converted);
|
||||
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
|
||||
ret = ConvertPrimitive(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<MetaFuncGraph>(obj)) {
|
||||
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<Type>(obj)) {
|
||||
converted = obj.cast<TypePtr>();
|
||||
} else if (py::isinstance<Tensor>(obj)) {
|
||||
|
@ -425,9 +412,50 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
|
||||
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
|
||||
} else {
|
||||
ret = ConvertOtherObj(obj, &converted);
|
||||
return false;
|
||||
}
|
||||
*data = converted;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
|
||||
// check parameter valid
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data is null pointer";
|
||||
return false;
|
||||
}
|
||||
|
||||
ValuePtr converted = nullptr;
|
||||
bool ret = ConvertSingleData(obj, &converted);
|
||||
if (ret) {
|
||||
*data = converted;
|
||||
return true;
|
||||
}
|
||||
if (py::isinstance<py::int_>(obj)) {
|
||||
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
|
||||
} else if (py::isinstance<py::dict>(obj)) {
|
||||
ret = ConvertDict(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<py::slice>(obj)) {
|
||||
ret = ConvertSlice(obj, &converted);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
ret = ConvertTuple(obj, &converted, use_signature);
|
||||
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
|
||||
ret = ConvertCellList(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<Cell>(obj)) {
|
||||
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
|
||||
} else if (py::isinstance<py::list>(obj)) {
|
||||
ret = ConvertList(obj, &converted, use_signature);
|
||||
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
|
||||
ret = ConvertPrimitive(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<MetaFuncGraph>(obj)) {
|
||||
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
|
||||
} else if (py::isinstance<FuncGraph>(obj)) {
|
||||
ret = ConvertFuncGraph(obj, &converted);
|
||||
} else {
|
||||
ret = ConvertOtherObj(obj, &converted);
|
||||
}
|
||||
*data = converted;
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -113,6 +113,49 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
|
|||
return para_node;
|
||||
}
|
||||
|
||||
void BroadenCNodeAbstract(const FuncGraphPtr &func_graph) {
|
||||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||
for (const AnfNodePtr &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto abstract = node->abstract();
|
||||
if (abstract != nullptr) {
|
||||
node->set_abstract(abstract->Broaden());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
|
||||
if (!value->isa<FuncGraph>()) {
|
||||
return;
|
||||
}
|
||||
auto resolved_graph = value->cast<FuncGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(resolved_graph);
|
||||
if (!resolved_graph->has_attr("is_load")) {
|
||||
return;
|
||||
}
|
||||
auto top_graph = Parser::GetTopFuncGraph();
|
||||
std::vector<AnfNodePtr> input_params;
|
||||
for (auto const ¶m : resolved_graph->parameters()) {
|
||||
auto param_ptr = dyn_cast<Parameter>(param);
|
||||
MS_EXCEPTION_IF_NULL(param_ptr);
|
||||
if (param_ptr->has_default()) {
|
||||
param_ptr->set_func_graph(top_graph);
|
||||
func_graph->add_used_global_parameters(param_ptr);
|
||||
|
||||
// update top_graph
|
||||
top_graph->add_parameter(param_ptr);
|
||||
size_t hyper_param_count = top_graph->hyper_param_count();
|
||||
top_graph->set_hyper_param_count(hyper_param_count + 1);
|
||||
} else {
|
||||
input_params.push_back(param_ptr);
|
||||
}
|
||||
}
|
||||
resolved_graph->set_parameters(input_params);
|
||||
BroadenCNodeAbstract(resolved_graph);
|
||||
}
|
||||
|
||||
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
|
||||
AnfNodePtr output = nullptr;
|
||||
if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
|
@ -146,6 +189,7 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
|
|||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(convert_result);
|
||||
ConvertLoadedGraph(func_graph, convert_result);
|
||||
output = NewValueNode(convert_result);
|
||||
if (convert_result->isa<tensor::Tensor>()) {
|
||||
output = GetMixedPrecisionCastHelp(func_graph, output);
|
||||
|
|
|
@ -48,6 +48,7 @@
|
|||
#include "pybind_api/pybind_patch.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/info.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/constants.h"
|
||||
#include "ps/util.h"
|
||||
|
@ -1096,6 +1097,8 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
|
|||
#endif
|
||||
}
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::LoadMindIR(file_name); }
|
||||
|
||||
void ReleaseGeTsd() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
if (context_ptr != nullptr) {
|
||||
|
|
|
@ -140,6 +140,7 @@ void ClearResAtexit();
|
|||
void ReleaseGeTsd();
|
||||
|
||||
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name);
|
||||
|
||||
// init and exec dataset sub graph
|
||||
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
|
||||
|
|
|
@ -51,6 +51,9 @@ void ValidateOperation(const AnfNodePtr &node) {
|
|||
if (abstract::IsInWhiteList(prim)) {
|
||||
return;
|
||||
}
|
||||
if (prim->HasAttr("is_load")) {
|
||||
return;
|
||||
}
|
||||
if (prim->HasPyEvaluator()) {
|
||||
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
|
||||
return;
|
||||
|
|
|
@ -273,6 +273,9 @@ class CNode : public AnfNode, public EffectInfoHolder {
|
|||
void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
|
||||
bool in_forward_flag() const { return in_forward_flag_; }
|
||||
|
||||
void set_load_flag(bool is_load) { is_load_ = is_load; }
|
||||
bool get_load_flag() { return is_load_; }
|
||||
|
||||
VarPtr func_graph_as_var() const { return func_graph_as_var_; }
|
||||
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
|
@ -304,6 +307,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
|
|||
bool stop_gradient_;
|
||||
bool in_forward_flag_ = false;
|
||||
bool effect_handled_ = false;
|
||||
bool is_load_ = false;
|
||||
// inputs_value_ store cnode input value and id in pynative mode
|
||||
// output_value_ store cnode value and id in pynative mode
|
||||
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
|
||||
|
|
|
@ -68,6 +68,18 @@ AnfNodePtr FuncGraph::output() const {
|
|||
}
|
||||
}
|
||||
|
||||
const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
|
||||
std::vector<AnfNodePtr> input_params;
|
||||
for (auto const &node : parameters_) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto parameter = dyn_cast<Parameter>(node);
|
||||
if (!parameter->has_default()) {
|
||||
input_params.push_back(parameter);
|
||||
}
|
||||
}
|
||||
return input_params;
|
||||
}
|
||||
|
||||
ParameterPtr FuncGraph::add_parameter() {
|
||||
FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
|
||||
ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
|
||||
|
|
|
@ -160,6 +160,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
abstract::AbstractFunctionPtr abstract();
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
// get function graph inputs, but parameters
|
||||
const std::vector<AnfNodePtr> get_inputs() const;
|
||||
// Return the graph's output, or nullptr if not yet deduced.
|
||||
AnfNodePtr output() const;
|
||||
void set_output(const AnfNodePtr &value, bool force_new_ret = false);
|
||||
|
|
|
@ -91,6 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
|
|||
new_node->set_forward(old_node->forward().first, old_node->forward().second);
|
||||
new_node->set_inputs_value(old_node->inputs_value());
|
||||
new_node->set_attrs(old_node->attrs());
|
||||
new_node->set_load_flag(old_node->get_load_flag());
|
||||
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
|
||||
new_node->set_scope(scope);
|
||||
new_node->CloneUserData(old_node);
|
||||
|
|
|
@ -228,17 +228,14 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
|
|||
}
|
||||
|
||||
if (!tensor_proto.has_data_type()) {
|
||||
MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!";
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!";
|
||||
}
|
||||
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!";
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not support yet!";
|
||||
}
|
||||
|
||||
tensor::TensorPtr tensor_info =
|
||||
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
return tensor_info;
|
||||
}
|
||||
|
||||
|
@ -253,9 +250,14 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
|
|||
string debug_info_name = ParseParameterName(parameter_proto.name());
|
||||
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
||||
node->set_debug_info(debug_info_ptr);
|
||||
node->set_name(parameter_proto.name());
|
||||
node->set_name(debug_info_name);
|
||||
|
||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
|
||||
param_info->set_name(debug_info_name);
|
||||
tensor_info->set_param_info(param_info);
|
||||
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
node->set_abstract(tensor_abstract);
|
||||
|
@ -284,13 +286,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
string debug_info_name = ParseParameterName(value_proto.name());
|
||||
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
||||
node->set_debug_info(debug_info_ptr);
|
||||
node->set_name(value_proto.name());
|
||||
node->set_name(debug_info_name);
|
||||
|
||||
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
|
||||
|
||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
node->set_abstract(tensor_abstract);
|
||||
|
||||
anfnode_build_map_[value_proto.name()] = node;
|
||||
|
@ -300,15 +302,6 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const mind_ir::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
|
||||
for (int i = 0; i < importProto.parameter_size(); ++i) {
|
||||
const mind_ir::TensorProto ¶meter_proto = importProto.parameter(i);
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
|
||||
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
|
||||
for (int i = 0; i < importProto.input_size(); ++i) {
|
||||
const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
|
||||
|
@ -317,6 +310,15 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
|
|||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
|
||||
for (int i = 0; i < importProto.parameter_size(); ++i) {
|
||||
const mind_ir::TensorProto ¶meter_proto = importProto.parameter(i);
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
|
||||
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -745,7 +747,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
|||
|
||||
inputs.push_back(anfnode_build_map_[input_name]);
|
||||
}
|
||||
|
||||
prim->set_attr("is_load", MakeValue(true));
|
||||
auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
|
||||
|
@ -777,6 +779,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
|||
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
||||
cnode_ptr->set_debug_info(debug_info_ptr);
|
||||
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
||||
cnode_ptr->set_load_flag(true);
|
||||
|
||||
anfnode_build_map_[node_name] = cnode_ptr;
|
||||
return cnode_ptr;
|
||||
|
@ -804,6 +807,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
|
|||
inputs.push_back(maketuple_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
return_node->set_load_flag(true);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
|
||||
} else {
|
||||
|
@ -812,6 +816,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
|
|||
inputs.push_back(cnode_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
return_node->set_load_flag(true);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ Pre-defined building blocks or computing units to construct neural networks.
|
|||
from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
|
||||
from .learning_rate_schedule import *
|
||||
from .dynamic_lr import *
|
||||
from .cell import Cell, GraphKernel
|
||||
from .cell import Cell, GraphKernel, GraphCell
|
||||
from .layer import *
|
||||
from .loss import *
|
||||
from .optim import *
|
||||
|
@ -29,7 +29,7 @@ from .wrap import *
|
|||
from .sparse import *
|
||||
|
||||
|
||||
__all__ = ["Cell", "GraphKernel"]
|
||||
__all__ = ["Cell", "GraphKernel", "GraphCell"]
|
||||
__all__.extend(layer.__all__)
|
||||
__all__.extend(loss.__all__)
|
||||
__all__.extend(optim.__all__)
|
||||
|
|
|
@ -25,7 +25,7 @@ from mindspore import log as logger
|
|||
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
||||
from mindspore.context import ParallelMode
|
||||
from .. import context
|
||||
from .._c_expression import init_pipeline, Cell_
|
||||
from .._c_expression import init_pipeline, Cell_, FuncGraph
|
||||
from .._checkparam import Validator
|
||||
from ..common import dtype as mstype
|
||||
from ..common.api import _executor, _pynative_exec
|
||||
|
@ -1191,3 +1191,39 @@ class GraphKernel(Cell):
|
|||
|
||||
def construct(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GraphCell(Cell):
|
||||
"""
|
||||
Base class for running the graph loaded from a MindIR.
|
||||
|
||||
This feature is still under development. Currently `GraphCell` do not support modifying the structure of the
|
||||
diagram, and can only use data that shape and type are the same as the input when exporting the MindIR.
|
||||
|
||||
Args:
|
||||
graph (object): A compiled graph loaded from MindIR.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.train import export, load
|
||||
>>>
|
||||
>>> net = nn.Conv2d(1, 1, kernel_size=3)
|
||||
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
||||
>>> export(net, input, file_name="net", file_format="MINDIR")
|
||||
>>> graph = load("net.mindir")
|
||||
>>> net = nn.GraphCell(graph)
|
||||
>>> output = net(input)
|
||||
"""
|
||||
def __init__(self, graph):
|
||||
super(GraphCell, self).__init__(auto_prefix=True)
|
||||
if not isinstance(graph, FuncGraph):
|
||||
raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
|
||||
self.graph = graph
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.graph(*inputs)
|
||||
|
||||
def __call__(self, *inputs):
|
||||
return self.compile_and_run(*inputs)
|
||||
|
|
|
@ -22,10 +22,10 @@ from .dataset_helper import DatasetHelper, connect_network_with_dataset
|
|||
from . import amp
|
||||
from .amp import build_train_network
|
||||
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
|
||||
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
|
||||
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
|
||||
build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint
|
||||
|
||||
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
|
||||
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
|
||||
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
|
||||
"load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
|
||||
"load_distributed_checkpoint"]
|
||||
|
|
|
@ -141,10 +141,22 @@ class Model:
|
|||
self._global_rank = _get_global_rank()
|
||||
self._parameter_broadcast = _get_parameter_broadcast()
|
||||
|
||||
self._check_for_graph_cell(kwargs)
|
||||
self._train_network = self._build_train_network()
|
||||
self._build_eval_network(metrics, eval_network, eval_indexes)
|
||||
self._build_predict_network()
|
||||
|
||||
def _check_for_graph_cell(self, kwargs):
|
||||
if not isinstance(self._network, nn.GraphCell):
|
||||
return
|
||||
if self._amp_level != "O0":
|
||||
logger.warning("amp_level will not work when network is a GraphCell.")
|
||||
|
||||
if self._loss_fn is not None or self._optimizer is not None:
|
||||
raise ValueError("Currently loss_fn and optimizer should be None when network is a GraphCell. ")
|
||||
if kwargs:
|
||||
raise ValueError("Currently kwargs should be empty when network is a GraphCell. ")
|
||||
|
||||
def _process_amp_args(self, kwargs):
|
||||
if self._amp_level in ["O0", "O3"]:
|
||||
self._keep_bn_fp32 = False
|
||||
|
@ -594,6 +606,8 @@ class Model:
|
|||
>>> model.train(2, dataset)
|
||||
"""
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
|
||||
raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
|
||||
Validator.check_is_int(sink_size)
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
if dataset_size == 0:
|
||||
|
@ -718,9 +732,12 @@ class Model:
|
|||
>>> acc = model.eval(dataset, dataset_sink_mode=False)
|
||||
"""
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode is True:
|
||||
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
|
|
|
@ -38,6 +38,7 @@ from mindspore._checkparam import check_input_data, Validator
|
|||
from mindspore.compression.export import quant_export
|
||||
from mindspore.parallel._tensor import _load_tensor
|
||||
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
|
||||
from .._c_expression import load_mindir
|
||||
|
||||
|
||||
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
||||
|
@ -229,6 +230,49 @@ def _check_param_prefix(filter_prefix, param_name):
|
|||
return False
|
||||
|
||||
|
||||
def load(file_name):
|
||||
"""
|
||||
Load MindIR.
|
||||
|
||||
The returned object can be executed by a `GraphCell`. However, there are some limitations to the current use
|
||||
of `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.
|
||||
|
||||
Args:
|
||||
file_name (str): MindIR file name.
|
||||
|
||||
Returns:
|
||||
Object, a compiled graph that can executed by `GraphCell`.
|
||||
|
||||
Raises:
|
||||
ValueError: MindIR file is incorrect.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.train import export, load
|
||||
>>>
|
||||
>>> net = nn.Conv2d(1, 1, kernel_size=3)
|
||||
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
||||
>>> export(net, input, file_name="net", file_format="MINDIR")
|
||||
>>> graph = load("net.mindir")
|
||||
>>> net = nn.GraphCell(graph)
|
||||
>>> output = net(input)
|
||||
"""
|
||||
if not isinstance(file_name, str):
|
||||
raise ValueError("The file name must be string.")
|
||||
if not os.path.exists(file_name):
|
||||
raise ValueError("The file is not exist.")
|
||||
if not file_name.endswith(".mindir"):
|
||||
raise ValueError("The MindIR should end with mindir, please input the correct file name.")
|
||||
|
||||
logger.info("Execute the process of loading mindir.")
|
||||
graph = load_mindir(file_name)
|
||||
if graph is None:
|
||||
raise RuntimeError("Load MindIR failed.")
|
||||
return graph
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
|
||||
"""
|
||||
Loads checkpoint info from a specified file.
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.common.initializer import TruncatedNormal
|
|||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.train.serialization import export
|
||||
from mindspore.train.serialization import export, load
|
||||
|
||||
|
||||
def weight_variable():
|
||||
|
@ -112,3 +112,26 @@ def test_export_lenet_grad_mindir():
|
|||
export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
|
||||
verify_name = "lenet_grad.mindir"
|
||||
assert os.path.exists(verify_name)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_load_mindir_and_run():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
network = LeNet5()
|
||||
network.set_train()
|
||||
|
||||
inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
|
||||
outputs0 = network(inputs0)
|
||||
|
||||
inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
|
||||
export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
|
||||
mindir_name = "test_lenet_load.mindir"
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
graph = load(mindir_name)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(inputs0)
|
||||
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
|
Loading…
Reference in New Issue