forked from mindspore-Ecosystem/mindspore
!21803 mindir: support control flow switch layer
Merge pull request !21803 from lanzhineng/mindir_control_flow
This commit is contained in:
commit
223f500bab
|
@ -121,30 +121,6 @@ using CompileGraphs = compile::CompileGraphs;
|
||||||
using abstract::AnalysisResult;
|
using abstract::AnalysisResult;
|
||||||
using mindspore::abstract::AnalysisContextPtr;
|
using mindspore::abstract::AnalysisContextPtr;
|
||||||
|
|
||||||
// Some operators are not defined.
|
|
||||||
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
|
|
||||||
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
|
|
||||||
// Process partial("DeadNode",args) when the graph is loaded.
|
|
||||||
auto operatorPtr = node->cast<CNodePtr>()->input(0);
|
|
||||||
// Set abstract of switch(c,f,t) to null
|
|
||||||
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
|
|
||||||
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
|
|
||||||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
|
|
||||||
node->set_abstract(nullptr);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// If the operator is not a primitive, the abstract will been set to null.
|
|
||||||
// Because there are not some operators in front end, the abstract of primitive should be reserved.
|
|
||||||
if (prim == nullptr) {
|
|
||||||
node->set_abstract(nullptr);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// Previous inferred value
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
||||||
const abstract::AbstractBasePtrList &args_spec, bool clear) {
|
const abstract::AbstractBasePtrList &args_spec, bool clear) {
|
||||||
MS_LOG(DEBUG) << "AbstractAnalyze start";
|
MS_LOG(DEBUG) << "AbstractAnalyze start";
|
||||||
|
@ -156,17 +132,22 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
|
||||||
engine->Clear();
|
engine->Clear();
|
||||||
for (auto &node : manager->all_nodes()) {
|
for (auto &node : manager->all_nodes()) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const AbstractBasePtr &prev_inferred = node->abstract();
|
|
||||||
|
|
||||||
// Handle previous inferred value for CNode if is loaded from MindIR
|
// Handle previous inferred value for CNode if is loaded from MindIR
|
||||||
if (res->is_load() && ResetCNodeFromLoad(node)) {
|
if (res->is_load()) {
|
||||||
|
// If the primitive is not defined in front end,keep the inferred value loaded from MindIR.
|
||||||
|
auto primitive = GetCNodePrimitive(node);
|
||||||
|
if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
|
||||||
|
MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const AbstractBasePtr &prev_inferred = node->abstract();
|
||||||
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
|
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
|
||||||
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
|
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
|
||||||
node->set_abstract(nullptr);
|
node->set_abstract(nullptr);
|
||||||
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
|
MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -275,7 +256,9 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
|
||||||
MS_EXCEPTION_IF_NULL(root_type);
|
MS_EXCEPTION_IF_NULL(root_type);
|
||||||
MS_EXCEPTION_IF_NULL(loaded_type);
|
MS_EXCEPTION_IF_NULL(loaded_type);
|
||||||
|
|
||||||
if (root_shape->shape() != loaded_shape->shape()) {
|
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
|
||||||
|
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
|
||||||
|
if (!shapeEqu) {
|
||||||
MS_EXCEPTION(ValueError) << "The " << index
|
MS_EXCEPTION(ValueError) << "The " << index
|
||||||
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
|
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
|
||||||
<< ", input shape of loaded graph: " << loaded_shape->ToString();
|
<< ", input shape of loaded graph: " << loaded_shape->ToString();
|
||||||
|
@ -531,8 +514,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
|
||||||
auto func_graph = res->func_graph();
|
auto func_graph = res->func_graph();
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
func_graph->DumpFuncGraph(fg_name);
|
func_graph->DumpFuncGraph(fg_name);
|
||||||
ExportIR(fg_name + ".dat", func_graph);
|
|
||||||
DumpIR(fg_name + ".ir", func_graph);
|
DumpIR(fg_name + ".ir", func_graph);
|
||||||
|
ExportIR(fg_name + ".dat", func_graph);
|
||||||
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
|
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
|
||||||
}
|
}
|
||||||
counter++;
|
counter++;
|
||||||
|
|
|
@ -359,7 +359,6 @@ void AnalysisEngine::Clear() {
|
||||||
root_context_ = nullptr;
|
root_context_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
|
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
|
||||||
// Custom Primitive with python infer_shape, infer_type
|
// Custom Primitive with python infer_shape, infer_type
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
@ -396,7 +395,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
||||||
engine->prim_py_evaluators_[prim_py] = evaluator;
|
engine->prim_py_evaluators_[prim_py] = evaluator;
|
||||||
return evaluator;
|
return evaluator;
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
|
MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// return a default evaluator
|
// return a default evaluator
|
||||||
|
@ -416,11 +416,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (evaluator == nullptr) {
|
if (evaluator == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
|
MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
|
||||||
}
|
}
|
||||||
return evaluator;
|
return evaluator;
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
|
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
|
||||||
MS_EXCEPTION_IF_NULL(func);
|
MS_EXCEPTION_IF_NULL(func);
|
||||||
|
@ -430,6 +429,9 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbs
|
||||||
}
|
}
|
||||||
auto primitive = func->prim();
|
auto primitive = func->prim();
|
||||||
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
|
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
|
||||||
|
if (evaluator == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
|
||||||
|
}
|
||||||
evaluators_[func] = evaluator;
|
evaluators_[func] = evaluator;
|
||||||
return evaluator;
|
return evaluator;
|
||||||
}
|
}
|
||||||
|
@ -1012,7 +1014,9 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
|
||||||
|
|
||||||
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
|
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
|
||||||
auto evaluator = GetPrimEvaluator(primitive, nullptr);
|
auto evaluator = GetPrimEvaluator(primitive, nullptr);
|
||||||
MS_EXCEPTION_IF_NULL(evaluator);
|
if (evaluator == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
|
||||||
|
}
|
||||||
if (!evaluator->isa<TrivialPrimEvaluator>()) {
|
if (!evaluator->isa<TrivialPrimEvaluator>()) {
|
||||||
MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
|
MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
|
||||||
<< evaluator->ToString();
|
<< evaluator->ToString();
|
||||||
|
|
|
@ -347,7 +347,7 @@ template <typename T>
|
||||||
AbstractBasePtr FromValue(const T &value, bool broaden = false) {
|
AbstractBasePtr FromValue(const T &value, bool broaden = false) {
|
||||||
return FromValueInside(MakeValue(value), broaden);
|
return FromValueInside(MakeValue(value), broaden);
|
||||||
}
|
}
|
||||||
|
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine);
|
||||||
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
|
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -137,10 +137,11 @@ class IrExportBuilder {
|
||||||
mind_ir::ModelProto model_;
|
mind_ir::ModelProto model_;
|
||||||
mind_ir::NodeProto *last_node_{nullptr};
|
mind_ir::NodeProto *last_node_{nullptr};
|
||||||
std::list<FuncGraphPtr> todo_;
|
std::list<FuncGraphPtr> todo_;
|
||||||
std::map<AnfNodePtr, size_t> node_index_map_;
|
std::map<AnfNodePtr, std::string> node_index_map_;
|
||||||
std::set<std::string> nodeName_;
|
std::set<std::string> nodeName_;
|
||||||
size_t node_index_{0};
|
size_t node_index_{0};
|
||||||
size_t shape_index_{0};
|
size_t shape_index_{0};
|
||||||
|
bool top_graph{true};
|
||||||
};
|
};
|
||||||
|
|
||||||
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
||||||
|
@ -185,9 +186,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
||||||
nodeName_.clear();
|
nodeName_.clear();
|
||||||
// Build the main funcGraph
|
// Build the main funcGraph
|
||||||
nodeName_.insert(func_graph->ToString());
|
nodeName_.insert(func_graph->ToString());
|
||||||
|
top_graph = true;
|
||||||
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
|
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
|
||||||
std::set<FuncGraphPtr> graphVisited;
|
std::set<FuncGraphPtr> graphVisited;
|
||||||
graphVisited.insert(func_graph);
|
graphVisited.insert(func_graph);
|
||||||
|
top_graph = false;
|
||||||
while (!todo_.empty()) {
|
while (!todo_.empty()) {
|
||||||
FuncGraphPtr fg = todo_.back();
|
FuncGraphPtr fg = todo_.back();
|
||||||
todo_.pop_back();
|
todo_.pop_back();
|
||||||
|
@ -204,6 +207,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
||||||
}
|
}
|
||||||
// Release resource
|
// Release resource
|
||||||
nodeName_.clear();
|
nodeName_.clear();
|
||||||
|
node_index_map_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||||
|
@ -227,8 +231,8 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
||||||
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
||||||
}
|
}
|
||||||
std::string param_name = GetUniqueNodeName(param);
|
std::string param_name = GetUniqueNodeName(param);
|
||||||
if (param->has_default()) {
|
if (top_graph && param->has_default()) {
|
||||||
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has default.";
|
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get();
|
||||||
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
|
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
|
||||||
parameter_proto->set_name(param_name);
|
parameter_proto->set_name(param_name);
|
||||||
SetParamToTensorProto(param, parameter_proto);
|
SetParamToTensorProto(param, parameter_proto);
|
||||||
|
@ -308,7 +312,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
||||||
} else if (type->isa<Tuple>()) {
|
} else if (type->isa<Tuple>()) {
|
||||||
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
||||||
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||||
} else if (type->isa<Number>() || type->isa<String>()) {
|
} else if (type->isa<Number>() || type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
|
||||||
value_proto->set_denotation(type->type_name());
|
value_proto->set_denotation(type->type_name());
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
|
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
|
||||||
|
@ -541,28 +545,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
|
||||||
// Naming anfnode
|
// Naming anfnode
|
||||||
// 1. parameter is unique in one func_graph
|
// 1. parameter is unique in one func_graph
|
||||||
// 2. cnode and valuenode may be reduplicative, so add index to identify.
|
// 2. cnode and valuenode may be reduplicative, so add index to identify.
|
||||||
std::string node_name = "";
|
|
||||||
if (node->isa<Parameter>()) {
|
|
||||||
node_name = GetNodeName(node);
|
|
||||||
} else if (node->isa<CNode>()) {
|
|
||||||
auto iter = node_index_map_.find(node);
|
auto iter = node_index_map_.find(node);
|
||||||
if (iter != node_index_map_.end()) {
|
if (iter != node_index_map_.end()) {
|
||||||
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
|
return iter->second;
|
||||||
} else {
|
} else {
|
||||||
|
std::string node_name = GetNodeName(node);
|
||||||
|
while (nodeName_.count(node_name) > 0) {
|
||||||
auto node_idx = GetNodeIndex();
|
auto node_idx = GetNodeIndex();
|
||||||
node_index_map_[node] = node_idx;
|
node_name = node_name + ":" + std::to_string(node_idx);
|
||||||
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
|
|
||||||
}
|
}
|
||||||
} else if (node->isa<ValueNode>()) {
|
node_index_map_[node] = node_name;
|
||||||
auto node_idx = GetNodeIndex();
|
|
||||||
node_index_map_[node] = node_idx;
|
|
||||||
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "Node name: " << node_name;
|
|
||||||
return node_name;
|
return node_name;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
||||||
std::string node_name = "";
|
std::string node_name = "";
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor, nn
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.train.serialization import export, load
|
||||||
|
|
||||||
|
|
||||||
|
class CaseNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CaseNet, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(1, 1, 3)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.relu1 = nn.ReLU()
|
||||||
|
self.softmax = nn.Softmax()
|
||||||
|
self.layers1 = (self.relu, self.softmax)
|
||||||
|
self.layers2 = (self.conv, self.relu1)
|
||||||
|
|
||||||
|
def construct(self, x, index1, index2):
|
||||||
|
x = self.layers1[index1](x)
|
||||||
|
x = self.layers2[index2](x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_mindir_switch_layer():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = CaseNet()
|
||||||
|
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
|
||||||
|
idx = Tensor(0, mstype.int32)
|
||||||
|
idx2 = Tensor(-1, mstype.int32)
|
||||||
|
|
||||||
|
file_name = "switch_layer_net"
|
||||||
|
mindir_name = file_name + ".mindir"
|
||||||
|
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
|
||||||
|
assert os.path.exists(mindir_name)
|
||||||
|
|
||||||
|
graph = load(mindir_name)
|
||||||
|
loaded_net = nn.GraphCell(graph)
|
||||||
|
outputs_after_load = loaded_net(data, idx, idx2)
|
||||||
|
relu = nn.ReLU()
|
||||||
|
true_value = relu(data)
|
||||||
|
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
|
||||||
|
assert ret
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="depend on export")
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_mindir_export():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = CaseNet()
|
||||||
|
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
|
||||||
|
idx = Tensor(0, mstype.int32)
|
||||||
|
idx2 = Tensor(-1, mstype.int32)
|
||||||
|
|
||||||
|
file_name = "switch_layer_net"
|
||||||
|
mindir_name = file_name + ".mindir"
|
||||||
|
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
|
||||||
|
assert os.path.exists(mindir_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="depend on export")
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_mindir_load():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
|
||||||
|
idx = Tensor(0, mstype.int32)
|
||||||
|
idx2 = Tensor(-1, mstype.int32)
|
||||||
|
|
||||||
|
file_name = "switch_layer_net"
|
||||||
|
mindir_name = file_name + ".mindir"
|
||||||
|
graph = load(mindir_name)
|
||||||
|
loaded_net = nn.GraphCell(graph)
|
||||||
|
outputs_after_load = loaded_net(data, idx, idx2)
|
||||||
|
relu = nn.ReLU()
|
||||||
|
true_value = relu(data)
|
||||||
|
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
|
||||||
|
assert ret
|
Loading…
Reference in New Issue