forked from mindspore-Ecosystem/mindspore
!21411 mindir: support control flow
Merge pull request !21411 from lanzhineng/mindir_control_flow
This commit is contained in:
commit
4d71b3bd76
|
@ -121,6 +121,28 @@ using CompileGraphs = compile::CompileGraphs;
|
||||||
using abstract::AnalysisResult;
|
using abstract::AnalysisResult;
|
||||||
using mindspore::abstract::AnalysisContextPtr;
|
using mindspore::abstract::AnalysisContextPtr;
|
||||||
|
|
||||||
|
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
|
||||||
|
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
|
||||||
|
// Process partial("DeadNode",args) when the graph is loaded.
|
||||||
|
auto operatorPtr = node->cast<CNodePtr>()->input(0);
|
||||||
|
// Set abstract of switch(c,f,t) to null
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
|
||||||
|
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
|
||||||
|
node->set_abstract(nullptr);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Set abstract of switch(c,f,t)() to null
|
||||||
|
prim = GetCNodePrimitive(operatorPtr);
|
||||||
|
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
|
||||||
|
node->set_abstract(nullptr);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Previous inferred value
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
||||||
const abstract::AbstractBasePtrList &args_spec, bool clear) {
|
const abstract::AbstractBasePtrList &args_spec, bool clear) {
|
||||||
MS_LOG(DEBUG) << "AbstractAnalyze start";
|
MS_LOG(DEBUG) << "AbstractAnalyze start";
|
||||||
|
@ -133,10 +155,19 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
|
||||||
for (auto &node : manager->all_nodes()) {
|
for (auto &node : manager->all_nodes()) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const AbstractBasePtr &prev_inferred = node->abstract();
|
const AbstractBasePtr &prev_inferred = node->abstract();
|
||||||
// Keep previous inferred value for CNode if is loaded from MindIR.
|
|
||||||
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
|
// AbstractFunction has context,but contexts in cache have been cleaned.
|
||||||
|
if (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>()) {
|
||||||
|
node->set_abstract(nullptr);
|
||||||
|
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle previous inferred value for CNode if is loaded from MindIR
|
||||||
|
if (res->is_load() && ResetCNodeFromLoad(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
|
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
|
||||||
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
|
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
|
||||||
node->set_abstract(nullptr);
|
node->set_abstract(nullptr);
|
||||||
|
@ -200,6 +231,7 @@ const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
|
||||||
if (graph->has_attr("is_load")) {
|
if (graph->has_attr("is_load")) {
|
||||||
loaded_graph = graph;
|
loaded_graph = graph;
|
||||||
loaded_graph_num += 1;
|
loaded_graph_num += 1;
|
||||||
|
res->set_is_load(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (loaded_graph_num == 0) {
|
if (loaded_graph_num == 0) {
|
||||||
|
@ -218,6 +250,8 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
|
||||||
FuncGraphPtr root_graph = *(manager->roots().begin());
|
FuncGraphPtr root_graph = *(manager->roots().begin());
|
||||||
auto root_inputs = root_graph->get_inputs();
|
auto root_inputs = root_graph->get_inputs();
|
||||||
auto loaded_inputs = loaded_graph->get_inputs();
|
auto loaded_inputs = loaded_graph->get_inputs();
|
||||||
|
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
|
||||||
|
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
|
||||||
|
|
||||||
size_t root_inputs_num = root_inputs.size();
|
size_t root_inputs_num = root_inputs.size();
|
||||||
size_t loaded_inputs_num = loaded_inputs.size();
|
size_t loaded_inputs_num = loaded_inputs.size();
|
||||||
|
@ -229,10 +263,18 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
|
||||||
auto root_input = root_inputs[index];
|
auto root_input = root_inputs[index];
|
||||||
auto loaded_input = loaded_inputs[index];
|
auto loaded_input = loaded_inputs[index];
|
||||||
|
|
||||||
|
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
|
||||||
|
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
|
||||||
|
MS_LOG(DEBUG) << "root_input abstract[" << index
|
||||||
|
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
|
||||||
|
MS_LOG(DEBUG) << "loaded_input abstract [" << index
|
||||||
|
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
|
||||||
|
|
||||||
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
|
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
|
||||||
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
|
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
|
||||||
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
|
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
|
||||||
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
|
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
|
||||||
|
|
||||||
MS_EXCEPTION_IF_NULL(root_shape);
|
MS_EXCEPTION_IF_NULL(root_shape);
|
||||||
MS_EXCEPTION_IF_NULL(loaded_shape);
|
MS_EXCEPTION_IF_NULL(loaded_shape);
|
||||||
MS_EXCEPTION_IF_NULL(root_type);
|
MS_EXCEPTION_IF_NULL(root_type);
|
||||||
|
@ -454,6 +496,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
||||||
}
|
}
|
||||||
// Analyze
|
// Analyze
|
||||||
AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
|
AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
|
||||||
|
|
||||||
// The top graph may be replaced by infer, update the top graph when the infer is done
|
// The top graph may be replaced by infer, update the top graph when the infer is done
|
||||||
parse::Parser::UpdateTopFuncGraph(result.context->func_graph());
|
parse::Parser::UpdateTopFuncGraph(result.context->func_graph());
|
||||||
|
|
||||||
|
|
|
@ -79,6 +79,8 @@ class Resource : public ResourceBase {
|
||||||
gpu_loopsink_flag_ = flag;
|
gpu_loopsink_flag_ = flag;
|
||||||
gpu_loopsink_size_ = size;
|
gpu_loopsink_size_ = size;
|
||||||
}
|
}
|
||||||
|
void set_is_load(bool flag) { is_load_ = flag; }
|
||||||
|
bool is_load() { return is_load_; }
|
||||||
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
|
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
|
||||||
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
|
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
|
||||||
// Reclaim resource and clear the cache.
|
// Reclaim resource and clear the cache.
|
||||||
|
@ -93,6 +95,8 @@ class Resource : public ResourceBase {
|
||||||
py::object input_;
|
py::object input_;
|
||||||
bool is_cleaned_;
|
bool is_cleaned_;
|
||||||
bool gpu_loopsink_flag_{false};
|
bool gpu_loopsink_flag_{false};
|
||||||
|
// The func_graph_ is loaded from mindir
|
||||||
|
bool is_load_{false};
|
||||||
int64_t gpu_loopsink_size_{1};
|
int64_t gpu_loopsink_size_{1};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -138,6 +138,7 @@ class IrExportBuilder {
|
||||||
mind_ir::NodeProto *last_node_{nullptr};
|
mind_ir::NodeProto *last_node_{nullptr};
|
||||||
std::list<FuncGraphPtr> todo_;
|
std::list<FuncGraphPtr> todo_;
|
||||||
std::map<AnfNodePtr, size_t> node_index_map_;
|
std::map<AnfNodePtr, size_t> node_index_map_;
|
||||||
|
std::set<std::string> nodeName_;
|
||||||
size_t node_index_{0};
|
size_t node_index_{0};
|
||||||
size_t shape_index_{0};
|
size_t shape_index_{0};
|
||||||
};
|
};
|
||||||
|
@ -145,16 +146,7 @@ class IrExportBuilder {
|
||||||
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
||||||
|
|
||||||
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
|
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
|
||||||
if ((builder_ == nullptr) || (func_graph == nullptr)) {
|
(void)GetDumpProto(func_graph);
|
||||||
MS_LOG(EXCEPTION) << "Input params is null.";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Export model info
|
|
||||||
builder_->BuildModelInfo();
|
|
||||||
|
|
||||||
// Export model and return string
|
|
||||||
builder_->BuildModel(func_graph);
|
|
||||||
|
|
||||||
return builder_->GetProtoString(func_graph);
|
return builder_->GetProtoString(func_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,7 +160,6 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo
|
||||||
|
|
||||||
// Export model and return string
|
// Export model and return string
|
||||||
builder_->BuildModel(func_graph, save_tensor_data);
|
builder_->BuildModel(func_graph, save_tensor_data);
|
||||||
|
|
||||||
return builder_->Model();
|
return builder_->Model();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,16 +182,34 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
||||||
graph_proto->set_bprop_hash(func_graph->bprop_hash());
|
graph_proto->set_bprop_hash(func_graph->bprop_hash());
|
||||||
ResetNodeIndex();
|
ResetNodeIndex();
|
||||||
todo_.clear();
|
todo_.clear();
|
||||||
todo_.push_back(func_graph);
|
nodeName_.clear();
|
||||||
|
// Build the main funcGraph
|
||||||
|
nodeName_.insert(func_graph->ToString());
|
||||||
|
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
|
||||||
|
std::set<FuncGraphPtr> graphVisited;
|
||||||
|
graphVisited.insert(func_graph);
|
||||||
while (!todo_.empty()) {
|
while (!todo_.empty()) {
|
||||||
FuncGraphPtr fg = todo_.back();
|
FuncGraphPtr fg = todo_.back();
|
||||||
todo_.pop_back();
|
todo_.pop_back();
|
||||||
BuildFuncGraph(fg, graph_proto, save_tensor_data);
|
if (graphVisited.count(fg) > 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (nodeName_.count(fg->ToString()) > 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
|
||||||
|
}
|
||||||
|
nodeName_.insert(fg->ToString());
|
||||||
|
graphVisited.insert(fg);
|
||||||
|
auto graph = model_.add_functions();
|
||||||
|
BuildFuncGraph(fg, graph, save_tensor_data);
|
||||||
}
|
}
|
||||||
|
// Release resource
|
||||||
|
nodeName_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||||
bool save_tensor_data) {
|
bool save_tensor_data) {
|
||||||
|
// Export funcGraph name.
|
||||||
|
graph_proto->set_name(func_graph->ToString());
|
||||||
// Export parameters
|
// Export parameters
|
||||||
// 1. parameters should be mapped to ValueInfoProto
|
// 1. parameters should be mapped to ValueInfoProto
|
||||||
// 2. parameters with default value should be mapped to Initializer
|
// 2. parameters with default value should be mapped to Initializer
|
||||||
|
@ -232,6 +241,10 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
||||||
input_proto->set_name(param_name);
|
input_proto->set_name(param_name);
|
||||||
SetValueInfoProto(param, input_proto);
|
SetValueInfoProto(param, input_proto);
|
||||||
}
|
}
|
||||||
|
if (nodeName_.count(param_name) > 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
|
||||||
|
}
|
||||||
|
nodeName_.insert(param_name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -383,9 +396,13 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
||||||
} else if (IsValueNode<FuncGraph>(node)) {
|
} else if (IsValueNode<FuncGraph>(node)) {
|
||||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
||||||
todo_.push_back(fg);
|
todo_.push_back(fg);
|
||||||
type_name = fg->ToString();
|
type_name = "REF::" + fg->ToString();
|
||||||
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
|
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
|
||||||
type_name = node->ToString();
|
auto nodeName = GetUniqueNodeName(node);
|
||||||
|
type_name = "REF::" + nodeName;
|
||||||
|
if (nodeName_.count(nodeName) == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
|
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
|
||||||
}
|
}
|
||||||
|
@ -424,6 +441,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
||||||
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
|
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
|
||||||
tensor_proto->add_dims(1);
|
tensor_proto->add_dims(1);
|
||||||
}
|
}
|
||||||
|
} else if (type->isa<Function>()) {
|
||||||
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_GRAPH);
|
||||||
|
*seq_string += type->type_name() + ",";
|
||||||
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
|
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
|
||||||
*seq_string += type->type_name() + ",";
|
*seq_string += type->type_name() + ",";
|
||||||
} else {
|
} else {
|
||||||
|
@ -468,6 +488,10 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
||||||
// Build cnode
|
// Build cnode
|
||||||
mind_ir::NodeProto *node_proto = graph_proto->add_node();
|
mind_ir::NodeProto *node_proto = graph_proto->add_node();
|
||||||
std::string output_name = GetUniqueNodeName(node);
|
std::string output_name = GetUniqueNodeName(node);
|
||||||
|
if (nodeName_.count(output_name) > 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "There is a duplicate name: " << output_name;
|
||||||
|
}
|
||||||
|
nodeName_.insert(output_name);
|
||||||
node_proto->add_output(output_name);
|
node_proto->add_output(output_name);
|
||||||
node_proto->set_name(output_name);
|
node_proto->set_name(output_name);
|
||||||
node_proto->set_domain(node->fullname_with_scope());
|
node_proto->set_domain(node->fullname_with_scope());
|
||||||
|
@ -475,7 +499,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
||||||
std::string type_name = GetOpTypeName(op);
|
std::string type_name = GetOpTypeName(op);
|
||||||
node_proto->set_op_type(type_name);
|
node_proto->set_op_type(type_name);
|
||||||
last_node_ = node_proto;
|
last_node_ = node_proto;
|
||||||
|
// Maybe Tensor or Function or nullptr
|
||||||
SetShapeToNodeProto(node, node_proto);
|
SetShapeToNodeProto(node, node_proto);
|
||||||
|
|
||||||
(void)std::for_each(input_names.begin(), input_names.end(),
|
(void)std::for_each(input_names.begin(), input_names.end(),
|
||||||
[&node_proto](const string &name) { node_proto->add_input(name); });
|
[&node_proto](const string &name) { node_proto->add_input(name); });
|
||||||
|
|
||||||
|
@ -490,13 +516,17 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
||||||
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
|
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
|
||||||
SetValueToAttributeProto(attr_value, attr_proto);
|
SetValueToAttributeProto(attr_value, attr_proto);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
||||||
std::string node_name = GetUniqueNodeName(node);
|
std::string node_name = GetUniqueNodeName(node);
|
||||||
|
// FuncGraph will be added to functions and the input name is the function name.
|
||||||
|
if (IsValueNode<FuncGraph>(node)) {
|
||||||
|
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
||||||
|
todo_.push_back(fg);
|
||||||
|
return fg->ToString();
|
||||||
|
}
|
||||||
if (node->isa<ValueNode>()) {
|
if (node->isa<ValueNode>()) {
|
||||||
// When node input is a ValueNode, need to create a Constant Node
|
// When node input is a ValueNode, need to create a Constant Node
|
||||||
mind_ir::NodeProto *node_proto = graph_proto->add_node();
|
mind_ir::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
@ -539,7 +569,12 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
||||||
if ((node != nullptr) && (node->func_graph() != nullptr)) {
|
if ((node != nullptr) && (node->func_graph() != nullptr)) {
|
||||||
node_name = node->func_graph()->ToString() + ":";
|
node_name = node->func_graph()->ToString() + ":";
|
||||||
}
|
}
|
||||||
node_name += node->ToString();
|
if (node->isa<ValueNode>()) {
|
||||||
|
// Needn't value
|
||||||
|
node_name += node->AnfNode::ToString();
|
||||||
|
} else {
|
||||||
|
node_name += node->ToString();
|
||||||
|
}
|
||||||
MS_LOG(DEBUG) << "GetNodeName: " << node_name;
|
MS_LOG(DEBUG) << "GetNodeName: " << node_name;
|
||||||
return node_name;
|
return node_name;
|
||||||
}
|
}
|
||||||
|
|
|
@ -635,14 +635,12 @@ bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_
|
||||||
const mind_ir::AttributeProto &attr_proto) {
|
const mind_ir::AttributeProto &attr_proto) {
|
||||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||||
if (ref_attr_name.find("UMonad") != std::string::npos) {
|
if (ref_attr_name.find("UMonad") != std::string::npos) {
|
||||||
const ValuePtr kUMonad = std::make_shared<UMonad>();
|
|
||||||
auto monad_abs = kUMonad->ToAbstract();
|
auto monad_abs = kUMonad->ToAbstract();
|
||||||
auto new_value_node = NewValueNode(kUMonad);
|
auto new_value_node = NewValueNode(kUMonad);
|
||||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||||
new_value_node->set_abstract(monad_abs);
|
new_value_node->set_abstract(monad_abs);
|
||||||
anfnode_build_map_[value_node_name] = new_value_node;
|
anfnode_build_map_[value_node_name] = new_value_node;
|
||||||
} else if (ref_attr_name.find("IOMonad") != std::string::npos) {
|
} else if (ref_attr_name.find("IOMonad") != std::string::npos) {
|
||||||
const ValuePtr kIOMonad = std::make_shared<IOMonad>();
|
|
||||||
auto monad_abs = kIOMonad->ToAbstract();
|
auto monad_abs = kIOMonad->ToAbstract();
|
||||||
auto new_value_node = NewValueNode(kIOMonad);
|
auto new_value_node = NewValueNode(kIOMonad);
|
||||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||||
|
@ -768,17 +766,22 @@ std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::Get
|
||||||
return kv;
|
return kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) {
|
||||||
const mind_ir::NodeProto &node_proto) {
|
const std::string kOperatorTypeFlag = std::string("REF::");
|
||||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
|
||||||
if (!node_proto.has_op_type()) {
|
|
||||||
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
const std::string &node_name = node_proto.output(0);
|
|
||||||
const std::string &fullname_with_scope = node_proto.domain();
|
|
||||||
const std::string &node_type = node_proto.op_type();
|
const std::string &node_type = node_proto.op_type();
|
||||||
|
MS_LOG(DEBUG) << "Process Operator :" << node_type;
|
||||||
|
// Operator maybe CNode,FuncGraph or Parameter.
|
||||||
|
|
||||||
|
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
|
||||||
|
auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize));
|
||||||
|
if (it != anfnode_build_map_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator is primitive.
|
||||||
std::shared_ptr<Primitive> prim;
|
std::shared_ptr<Primitive> prim;
|
||||||
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
|
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
|
||||||
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
|
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
|
||||||
|
@ -794,51 +797,65 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||||
|
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||||
|
// CNode abstract
|
||||||
|
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prim->set_attr("is_load", MakeValue(true));
|
||||||
|
return std::make_shared<ValueNode>(prim);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set CNode abstract.
|
||||||
|
void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
|
||||||
|
const std::string &node_type = node_proto.op_type();
|
||||||
|
// Handle control flow operator.
|
||||||
|
auto operatorPtr = cnode_ptr->input(0);
|
||||||
|
// Set abstract of switch(c,f,t),switchLayer(c,tup) and
|
||||||
|
// partial(func,args) to null
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
|
||||||
|
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
|
||||||
|
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
|
||||||
|
cnode_ptr->set_abstract(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Set abstract of switch(c,f,t)() to null
|
||||||
|
prim = GetCNodePrimitive(operatorPtr);
|
||||||
|
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
|
||||||
|
cnode_ptr->set_abstract(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
|
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
|
||||||
string shape_ref_attr_name;
|
string shape_ref_attr_name;
|
||||||
|
|
||||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||||
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
|
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||||
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
||||||
shape_ref_attr_name = attr_proto.ref_attr_name();
|
shape_ref_attr_name = attr_proto.ref_attr_name();
|
||||||
kv = GetAbstractForCNode(attr_proto);
|
kv = GetAbstractForCNode(attr_proto);
|
||||||
continue;
|
break;
|
||||||
}
|
|
||||||
|
|
||||||
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
|
||||||
MS_LOG(ERROR) << "Get CNode attr failed!";
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> inputs;
|
// Because there is not context in unit test,
|
||||||
inputs.clear();
|
// abstract->broaden() is replaced by abstract->set_value(kAnyValue).
|
||||||
for (int i = 0; i < node_proto.input_size(); ++i) {
|
|
||||||
const std::string &input_name = node_proto.input(i);
|
|
||||||
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
|
||||||
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
inputs.push_back(anfnode_build_map_[input_name]);
|
|
||||||
}
|
|
||||||
prim->set_attr("is_load", MakeValue(true));
|
|
||||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
|
||||||
|
|
||||||
if (kv.size() == 0) {
|
if (kv.size() == 0) {
|
||||||
if (node_type == "UpdateState") {
|
if (node_type == "UpdateState") {
|
||||||
const ValuePtr kUMonad = std::make_shared<UMonad>();
|
cnode_ptr->set_abstract(kUMonad->ToAbstract());
|
||||||
auto monad_abs = kUMonad->ToAbstract();
|
|
||||||
cnode_ptr->set_abstract(monad_abs);
|
|
||||||
} else if (node_type == "Depend") {
|
} else if (node_type == "Depend") {
|
||||||
const ValuePtr kBool = std::make_shared<BoolImm>(true);
|
|
||||||
cnode_ptr->set_abstract(kBool->ToAbstract());
|
cnode_ptr->set_abstract(kBool->ToAbstract());
|
||||||
} else {
|
} else {
|
||||||
AbstractBasePtrList elem;
|
AbstractBasePtrList elem;
|
||||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||||
auto abs = cnode_ptr->input(index)->abstract();
|
auto abs = cnode_ptr->input(index)->abstract();
|
||||||
if (abs != nullptr) {
|
if (abs != nullptr) {
|
||||||
|
abs->set_value(kAnyValue);
|
||||||
elem.push_back(abs);
|
elem.push_back(abs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -848,22 +865,56 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
||||||
}
|
}
|
||||||
} else if (kv.size() == 1) {
|
} else if (kv.size() == 1) {
|
||||||
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
|
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
|
||||||
cnode_ptr->set_abstract(iter->second);
|
if (iter->second != nullptr) {
|
||||||
|
iter->second->set_value(kAnyValue);
|
||||||
|
cnode_ptr->set_abstract(iter->second);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
|
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
|
||||||
if (abstract == nullptr) {
|
if (abstract == nullptr) {
|
||||||
|
cnode_ptr->set_abstract(nullptr);
|
||||||
MS_LOG(ERROR) << "Node's attribute is nullptr.";
|
MS_LOG(ERROR) << "Node's attribute is nullptr.";
|
||||||
|
} else {
|
||||||
|
abstract->set_value(kAnyValue);
|
||||||
|
cnode_ptr->set_abstract(abstract);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||||
|
const mind_ir::NodeProto &node_proto) {
|
||||||
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||||
|
if (!node_proto.has_op_type()) {
|
||||||
|
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const std::string &node_name = node_proto.output(0);
|
||||||
|
MS_LOG(DEBUG) << "Process CNode: " << node_name;
|
||||||
|
// Build inputs.
|
||||||
|
std::vector<AnfNodePtr> inputs;
|
||||||
|
inputs.push_back(BuildOperatorNode(node_proto));
|
||||||
|
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||||
|
const std::string &input_name = node_proto.input(i);
|
||||||
|
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
||||||
|
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
cnode_ptr->set_abstract(abstract);
|
inputs.push_back(anfnode_build_map_[input_name]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||||
|
SetCNodeAbastract(node_proto, cnode_ptr);
|
||||||
|
|
||||||
|
const std::string &fullname_with_scope = node_proto.domain();
|
||||||
string debug_info_name = ParseCNodeName(node_name);
|
string debug_info_name = ParseCNodeName(node_name);
|
||||||
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
||||||
cnode_ptr->set_debug_info(debug_info_ptr);
|
cnode_ptr->set_debug_info(debug_info_ptr);
|
||||||
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
||||||
cnode_ptr->set_load_flag(true);
|
cnode_ptr->set_load_flag(true);
|
||||||
|
if (anfnode_build_map_.count(node_name) > 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Duplicate CNode name: " << node_name;
|
||||||
|
}
|
||||||
anfnode_build_map_[node_name] = cnode_ptr;
|
anfnode_build_map_[node_name] = cnode_ptr;
|
||||||
return cnode_ptr;
|
return cnode_ptr;
|
||||||
}
|
}
|
||||||
|
@ -991,11 +1042,41 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
|
||||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||||
}
|
}
|
||||||
const mind_ir::GraphProto &graphBuild = model_proto.graph();
|
const mind_ir::GraphProto &graphBuild = model_proto.graph();
|
||||||
|
|
||||||
|
// Forward declare FuncGraph name
|
||||||
|
// Compatible with the previous proto.
|
||||||
|
if (graphBuild.has_name()) {
|
||||||
|
anfnode_build_map_[graphBuild.name()] = std::make_shared<ValueNode>(dstGraph);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < model_proto.functions_size(); ++i) {
|
||||||
|
FuncGraphPtr graph = std::make_shared<FuncGraph>();
|
||||||
|
const auto &graph_proto = model_proto.functions(i);
|
||||||
|
if (!graph_proto.has_name()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The function has not a name. Please export mindIR again. ";
|
||||||
|
}
|
||||||
|
if (anfnode_build_map_.count(graph_proto.name()) > 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "There is a duplication function graph name: " << graph_proto.name();
|
||||||
|
}
|
||||||
|
anfnode_build_map_[graph_proto.name()] = std::make_shared<ValueNode>(graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parser the proto.
|
||||||
if (!BuildFuncGraph(dstGraph, graphBuild)) {
|
if (!BuildFuncGraph(dstGraph, graphBuild)) {
|
||||||
MS_LOG(ERROR) << "Build funcgraph failed!";
|
MS_LOG(ERROR) << "Build funcgraph failed!";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
|
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
|
||||||
|
for (int i = 0; i < model_proto.functions_size(); ++i) {
|
||||||
|
const auto &graph_proto = model_proto.functions(i);
|
||||||
|
FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
|
||||||
|
if (!BuildFuncGraph(graph, graph_proto)) {
|
||||||
|
MS_LOG(ERROR) << "Build funcgraph failed!";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name();
|
||||||
|
}
|
||||||
|
// Release resource
|
||||||
|
anfnode_build_map_.clear();
|
||||||
return dstGraph;
|
return dstGraph;
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -62,6 +62,8 @@ class MSANFModelParser {
|
||||||
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
|
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
|
||||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||||
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
|
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
|
||||||
|
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
|
||||||
|
void SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr);
|
||||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
||||||
bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||||
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
|
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
|
||||||
|
|
|
@ -23,6 +23,9 @@ message AttributeProto {
|
||||||
TENSOR = 17;
|
TENSOR = 17;
|
||||||
GRAPH = 18;
|
GRAPH = 18;
|
||||||
TENSORS = 19;
|
TENSORS = 19;
|
||||||
|
TUPLE = 20; // tuple
|
||||||
|
LIST = 21; // list
|
||||||
|
DICT = 22; // dictionary
|
||||||
}
|
}
|
||||||
optional string name = 1;
|
optional string name = 1;
|
||||||
optional float f = 2;
|
optional float f = 2;
|
||||||
|
@ -40,6 +43,8 @@ message AttributeProto {
|
||||||
optional string doc_string = 14;
|
optional string doc_string = 14;
|
||||||
optional string ref_attr_name = 15;
|
optional string ref_attr_name = 15;
|
||||||
optional AttributeType type = 16;
|
optional AttributeType type = 16;
|
||||||
|
repeated AttributeProto values = 17; // tuple, list,dict of value
|
||||||
|
optional AttributeType type_val = 18; // type type info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,6 +75,7 @@ message ModelProto {
|
||||||
optional string model_version = 5;
|
optional string model_version = 5;
|
||||||
optional string doc_string = 6;
|
optional string doc_string = 6;
|
||||||
optional GraphProto graph = 7;
|
optional GraphProto graph = 7;
|
||||||
|
repeated GraphProto functions = 8; // all the graphs without the main graph.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,171 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore.common.initializer import TruncatedNormal
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.train.serialization import export, load
|
||||||
|
|
||||||
|
|
||||||
|
def weight_variable():
|
||||||
|
return TruncatedNormal(0.02)
|
||||||
|
|
||||||
|
|
||||||
|
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||||
|
weight = weight_variable()
|
||||||
|
return nn.Conv2d(in_channels, out_channels,
|
||||||
|
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
weight_init=weight, has_bias=False, pad_mode="valid")
|
||||||
|
|
||||||
|
|
||||||
|
def fc_with_initialize(input_channels, out_channels):
|
||||||
|
weight = weight_variable()
|
||||||
|
bias = weight_variable()
|
||||||
|
return nn.Dense(input_channels, out_channels, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
class LeNet5(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(LeNet5, self).__init__()
|
||||||
|
self.batch_size = 32
|
||||||
|
self.conv1 = conv(1, 6, 5)
|
||||||
|
self.conv2 = conv(6, 16, 5)
|
||||||
|
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||||
|
self.fc2 = fc_with_initialize(120, 84)
|
||||||
|
self.fc3 = fc_with_initialize(84, 10)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.max_pool2d(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.max_pool2d(x)
|
||||||
|
x = self.reshape(x, (self.batch_size, -1))
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.fc3(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WithLossCell(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||||
|
self.loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, x, label):
|
||||||
|
predict = self.network(x)
|
||||||
|
return self.loss(predict, label)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainOneStepCell(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.network.set_train()
|
||||||
|
self.weights = ParameterTuple(network.trainable_params())
|
||||||
|
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
self.grad = C.GradOperation(get_by_list=True)
|
||||||
|
|
||||||
|
def construct(self, x, label):
|
||||||
|
weights = self.weights
|
||||||
|
grads = self.grad(self.network, weights)(x, label)
|
||||||
|
return self.optimizer(grads)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleIfNet(nn.Cell):
|
||||||
|
def construct(self, x, y):
|
||||||
|
x += 1
|
||||||
|
if x < y:
|
||||||
|
y += x
|
||||||
|
else:
|
||||||
|
y -= x
|
||||||
|
y += 5
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_export_lenet_grad_mindir():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
network = LeNet5()
|
||||||
|
network.set_train()
|
||||||
|
predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
|
||||||
|
label = Tensor(np.zeros([32, 10]).astype(np.float32))
|
||||||
|
net = TrainOneStepCell(WithLossCell(network))
|
||||||
|
export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
|
||||||
|
verify_name = "lenet_grad.mindir"
|
||||||
|
assert os.path.exists(verify_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_load_mindir_and_run():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
network = LeNet5()
|
||||||
|
network.set_train()
|
||||||
|
|
||||||
|
inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
|
||||||
|
outputs0 = network(inputs0)
|
||||||
|
|
||||||
|
inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
|
||||||
|
export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
|
||||||
|
mindir_name = "test_lenet_load.mindir"
|
||||||
|
assert os.path.exists(mindir_name)
|
||||||
|
|
||||||
|
graph = load(mindir_name)
|
||||||
|
loaded_net = nn.GraphCell(graph)
|
||||||
|
outputs_after_load = loaded_net(inputs0)
|
||||||
|
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_single_if():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path="./ifir")
|
||||||
|
network = SingleIfNet()
|
||||||
|
|
||||||
|
x = Tensor(np.array([1]).astype(np.float32))
|
||||||
|
y = Tensor(np.array([2]).astype(np.float32))
|
||||||
|
origin_out = network(x, y)
|
||||||
|
|
||||||
|
file_name = "if_net"
|
||||||
|
export(network, x, y, file_name=file_name, file_format='MINDIR')
|
||||||
|
mindir_name = file_name + ".mindir"
|
||||||
|
assert os.path.exists(mindir_name)
|
||||||
|
|
||||||
|
graph = load(mindir_name)
|
||||||
|
loaded_net = nn.GraphCell(graph)
|
||||||
|
outputs_after_load = loaded_net(x, y)
|
||||||
|
assert origin_out == outputs_after_load
|
Loading…
Reference in New Issue