!2002 Add dump ir function in binary format

Merge pull request !2002 from leopz/test_dump
This commit is contained in:
mindspore-ci-bot 2020-06-12 14:58:02 +08:00 committed by Gitee
commit ef35e2d990
6 changed files with 657 additions and 6 deletions

View File

@ -1566,7 +1566,7 @@ class IrParser {
return lexer_.GetNextToken();
} else if (type == "Tuple") {
return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
} else if (type == "Array") {
} else if (type == "Tensor") {
return ParseTypeArray(func_graph, lexer_.GetNextToken(), ptr);
} else if (type == "List") {
return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);

View File

@ -118,6 +118,8 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_

View File

@ -0,0 +1,631 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <algorithm>
#include <functional>
#include "ir/param_value_py.h"
#include "debug/anf_ir_utils.h"
#include "operator/ops.h"
#include "proto/onnx.pb.h"
namespace mindspore {
using FloatPtr = std::shared_ptr<Float>;
using IntPtr = std::shared_ptr<Int>;
// anf type to onnx type map
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_type_map = {
{kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
{kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
{kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
{kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
{kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
{kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
{kObjectTypeString, onnx::TensorProto_DataType_STRING},
};
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_int_map = {
{8, onnx::TensorProto_DataType_INT8},
{16, onnx::TensorProto_DataType_INT16},
{32, onnx::TensorProto_DataType_INT32},
{64, onnx::TensorProto_DataType_INT64},
};
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_float_map = {
{16, onnx::TensorProto_DataType_FLOAT16},
{32, onnx::TensorProto_DataType_FLOAT},
};
// Can build different builder according to format
class IrExportBuilder;
using IrExportBuilderPtr = std::shared_ptr<IrExportBuilder>;
class IrExporter {
public:
explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
virtual ~IrExporter() = default;
std::string GetDumpString(const FuncGraphPtr &func_graph);
private:
IrExportBuilderPtr builder_;
};
class IrExportBuilder {
public:
IrExportBuilder() = default;
~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
std::string GetProtoString(const FuncGraphPtr &func_graph);
void BuildModelInfo();
void BuildModel(const FuncGraphPtr &func_graph);
private:
void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto);
void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto);
std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto);
void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto);
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto);
void SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto);
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto);
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits);
std::string GetNodeName(const AnfNodePtr &node);
std::string GetUniqueNodeName(const AnfNodePtr &node);
std::string GetOpTypeName(const AnfNodePtr &node);
size_t AllocateIndex() { return ++node_index_; }
void ResetIndex() { node_index_ = 0; }
private:
onnx::ModelProto model_;
onnx::NodeProto *last_node_;
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
size_t node_index_ = 0;
};
using IrExporterPtr = std::shared_ptr<IrExporter>;
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
if ((builder_ == nullptr) || (func_graph == nullptr)) {
MS_LOG(EXCEPTION) << "Input params is null.";
}
// Export model info
builder_->BuildModelInfo();
// Export model and return string
builder_->BuildModel(func_graph);
return builder_->GetProtoString(func_graph);
}
std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "BuildModel complete!";
return model_.SerializeAsString();
}
void IrExportBuilder::BuildModelInfo() {
model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
model_.set_producer_name("MindSpore");
model_.set_model_version(1);
}
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
onnx::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString());
ResetIndex();
todo_.clear();
todo_.push_back(func_graph);
while (!todo_.empty()) {
FuncGraphPtr fg = todo_.back();
todo_.pop_back();
BuildFuncGraph(fg, graph_proto);
}
}
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
// Export parameters
// 1. parameters should be mapped to ValueInfoProto
// 2. parameters with default value should be mapped to Initializer
BuildParameters(func_graph, graph_proto);
// Export operator nodes(include output)
BuildNodes(func_graph, graph_proto);
}
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
for (auto &item : func_graph->parameters()) {
auto param = item->cast<ParameterPtr>();
if (param == nullptr) {
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
}
onnx::ValueInfoProto *input_proto = graph_proto->add_input();
std::string param_name = GetUniqueNodeName(param);
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
if (!param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
continue;
}
// Using ONNX initializer to set parameter's default value
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_name);
SetParamToTensorProto(param, initializer_proto);
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
py::object obj = param_value->value();
py::object data = obj.attr("data");
if (py::isinstance<tensor::Tensor>(data)) {
auto method = data.attr("asnumpy");
py::array npy_data = method();
initializer_proto->set_raw_data(npy_data.request(true).ptr, static_cast<size_t>(npy_data.nbytes()));
}
}
}
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) {
auto iter = g_data_type_map.find(type_id);
if (iter == g_data_type_map.end()) {
MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
}
return iter->second;
}
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) {
auto iter = g_data_bits_int_map.find(bits);
if (iter == g_data_bits_int_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
}
return iter->second;
}
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) {
auto iter = g_data_bits_float_map.find(bits);
if (iter == g_data_bits_float_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
}
return iter->second;
}
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) {
if (node == nullptr || value_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
}
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
SetValueInfoProto(node->Type(), node->Shape(), value_proto);
}
void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::ValueInfoProto *const value_proto) {
onnx::TypeProto *type_proto = value_proto->mutable_type();
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = type->cast<TensorTypePtr>();
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
}
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
} else {
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
}
}
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("tensor");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
auto data = value->cast<tensor::TensorPtr>();
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
auto shape = data->shape_c();
tensor_proto->set_data_type(GetOnnxDataType(dtype));
for (const auto &dim : shape) {
tensor_proto->add_dims(dim);
}
}
void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::TensorProto *const tensor_proto) {
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
}
auto tensor = type->cast<TensorTypePtr>();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id()));
for (const auto &dim : dims) {
tensor_proto->add_dims(dim);
}
}
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto) {
if (param == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
}
MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
SetTensorProto(param->Type(), param->Shape(), tensor_proto);
}
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == func_graph->get_return()) {
BuildOutput(cnode, graph_proto);
} else {
BuildCNode(cnode, graph_proto);
}
}
}
void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
if (node->size() != 2) {
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
}
AnfNodePtr arg = node->input(1);
// Using make_tuple to set multi-output
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
auto tuple_node = arg->cast<CNodePtr>();
for (size_t i = 1; i < tuple_node->size(); i++) {
auto input_node = arg->cast<CNodePtr>()->input(i);
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
auto output_name = GetUniqueNodeName(tuple_node->input(i));
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(tuple_node->input(i), output_proto);
}
} else {
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(arg, output_proto);
}
}
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
// May be ValueNode/CNode/Parameter
std::string type_name = "";
if (IsValueNode<Primitive>(node)) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
type_name = prim->ToString();
} else if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
todo_.push_back(fg);
type_name = fg->ToString();
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
type_name = node->ToString();
} else {
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
}
MS_LOG(DEBUG) << "ExportType: " << type_name;
return type_name;
}
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::NodeProto *const node_proto) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_ref_attr_name("shape");
attr_proto->set_name("shape");
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetTensorProto(type, shape, tensor_proto);
}
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
onnx::NodeProto *const node_proto) {
// Get shape of cnode
// 1. prim kPrimTupleGetItem need to get shape of input node according to the index
// 2. some cnode doesn't has shape, such as LayerNorm
// 3. other cnodes have shape
if (node->IsApply(prim::kPrimTupleGetItem)) {
// Get index of tuple get_item
int index_pos = inputs.size() - 1;
if (!inputs[index_pos]->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos;
}
auto value = inputs[index_pos]->cast<ValueNodePtr>()->value();
if (!value->isa<IntergerImm>()) {
MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name();
}
size_t index = GetValue<int>(value);
// Get type and shape of input node
auto tup_type = inputs[0]->Type();
if (!tup_type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name();
}
auto type = tup_type->cast<TuplePtr>()->elements()[index];
auto tup_shape = inputs[0]->Shape()->cast<abstract::TupleShapePtr>();
if (index >= tup_shape->shape().size()) {
MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size();
}
auto shape = tup_shape->shape()[index];
SetShapeToNodeProto(type, shape, node_proto);
} else {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
return;
}
SetShapeToNodeProto(type, shape, node_proto);
}
}
void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
auto inputs_size = node->size();
if (inputs_size < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
// Need to build input node before dealing with cnode
std::vector<AnfNodePtr> op_inputs;
std::vector<string> input_names;
for (size_t i = 1; i < inputs_size; i++) {
auto input = node->input(i);
op_inputs.push_back(input);
input_names.push_back(BuildInputNode(input, graph_proto));
}
// Build cnode
onnx::NodeProto *node_proto = graph_proto->add_node();
std::string output_name = GetUniqueNodeName(node);
node_proto->add_output(output_name);
node_proto->set_name(output_name);
AnfNodePtr op = node->input(0);
std::string type_name = GetOpTypeName(op);
node_proto->set_op_type(type_name);
last_node_ = node_proto;
SetShapeToNodeProto(node, op_inputs, node_proto);
(void)std::for_each(input_names.begin(), input_names.end(),
[&node_proto](const string &name) { node_proto->add_input(name); });
// Add primitive attrs
if (IsValueNode<Primitive>(op)) {
auto prim = GetValueNode<PrimitivePtr>(op);
for (auto attr : prim->attrs()) {
MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first);
SetValueToAttributeProto(attr.second, attr_proto);
}
} else {
MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
}
}
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) {
std::string node_name = GetUniqueNodeName(node);
if (node->isa<ValueNode>()) {
// When node input is a ValueNode, need to create a Constant Node
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->add_output(node_name);
SetAttributeProto(node, node_proto);
}
return node_name;
}
std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
// Naming anfnode
// 1. parameter is unique in one func_graph
// 2. cnode and valuenode may be reduplicative, so add index to identify.
std::string node_name = "";
if (node->isa<Parameter>()) {
node_name = GetNodeName(node);
} else if (node->isa<CNode>() || node->isa<ValueNode>()) {
auto iter = node_index_map_.find(node);
if (iter != node_index_map_.end()) {
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
} else {
auto node_idx = AllocateIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
}
} else {
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
}
MS_LOG(DEBUG) << "Node name: " << node_name;
return node_name;
}
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
std::string node_name = "";
if ((node != nullptr) && (node->func_graph() != nullptr)) {
node_name = node->func_graph()->ToString() + ":";
}
node_name += node->ToString();
MS_LOG(DEBUG) << "GetNodeName: " << node_name;
return node_name;
}
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) {
if (node == nullptr || node_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
}
auto value = node->cast<ValueNodePtr>()->value();
node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
SetValueToAttributeProto(value, attr_proto);
}
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("type");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
if (value->isa<Int>()) {
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
} else if (value->isa<Float>()) {
auto float_value = value->cast<FloatPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
} else if (value->isa<TensorType>()) {
tensor_proto->set_name("tensor");
auto elem_type = value->cast<TensorTypePtr>()->element();
if (elem_type->isa<Int>()) {
auto int_value = elem_type->cast<IntPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
} else if (elem_type->isa<Float>()) {
auto float_value = elem_type->cast<FloatPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
} else {
MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
}
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
}
}
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
if (value->isa<StringImm>() || value->isa<Scalar>()) {
SetScalarToAttributeProto(value, attr_proto);
} else if (value->isa<Number>() || value->isa<TensorType>()) {
SetTypeToAttributeProto(value, attr_proto);
} else if (value->isa<ValueSequeue>()) {
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
} else if (value->isa<tensor::Tensor>()) {
SetTensorToAttributeProto(value, attr_proto);
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
}
}
void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetScalarToProto(value, tensor_proto);
}
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
if (value == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
}
if (value->isa<StringImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
tensor_proto->add_string_data(GetValue<std::string>(value));
} else if (value->isa<BoolImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
tensor_proto->add_int32_data(GetValue<bool>(value));
} else if (value->isa<Int8Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8);
tensor_proto->add_int32_data(value->cast<Int8ImmPtr>()->value());
} else if (value->isa<Int16Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16);
tensor_proto->add_int32_data(value->cast<Int16ImmPtr>()->value());
} else if (value->isa<Int32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32);
tensor_proto->add_int32_data(value->cast<Int32ImmPtr>()->value());
} else if (value->isa<Int64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
} else if (value->isa<FloatImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
tensor_proto->add_float_data(GetValue<float>(value));
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
}
}
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
onnx::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
if (value->isa<ValueTuple>()) {
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
if (tuple_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
return;
}
auto type_id = tuple_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : tuple_value->value()) {
SetScalarToProto(item, tensor_proto);
}
} else if (value->isa<ValueList>()) {
const ValueListPtr &list_value = value->cast<ValueListPtr>();
if (list_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
return;
}
auto type_id = list_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : list_value->value()) {
SetScalarToProto(item, tensor_proto);
}
}
}
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
auto builder = std::make_shared<IrExportBuilder>();
if (builder == nullptr) {
MS_LOG(ERROR) << "Create ir exporter failed!";
return "";
}
auto exporter = std::make_shared<IrExporter>(builder);
if (exporter == nullptr) {
return "";
}
return exporter->GetDumpString(func_graph);
}
} // namespace mindspore

View File

@ -59,6 +59,7 @@ using mindspore::abstract::AbstractTuplePtr;
const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
const char IR_TYPE_BINARY[] = "binary_ir";
ExecutorPyPtr ExecutorPy::executor_ = nullptr;
std::mutex ExecutorPy::instance_lock_;
@ -212,6 +213,14 @@ py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::str
return proto_str;
}
if (ir_type == IR_TYPE_BINARY) {
std::string proto_str = GetBinaryProtoString(fg_ptr);
if (proto_str.empty()) {
MS_LOG(EXCEPTION) << "Graph proto is empty.";
}
return proto_str;
}
MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
}
@ -506,7 +515,6 @@ void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource,
// when in loading anf ir mode, action `parse` do nothing
if (action.first == "parse") {
parse::PythonAdapter::SetPythonEnvFlag(true);
return;
}
@ -566,6 +574,7 @@ void Pipeline::Run() {
draw::Draw(base_name + ".dot", graph);
// generate IR file in human readable format
DumpIR(base_name + ".ir", graph);
// generate IR file in a heavily commented format, which can also be reloaded
if (action.first != "parse") {
ExportIR(base_name + ".dat", std::to_string(i), graph);

View File

@ -398,17 +398,18 @@ def export(net, *inputs, file_name, file_format='GEIR'):
net (Cell): MindSpore network.
inputs (Tensor): Inputs of the `net`.
file_name (str): File name of model to export.
file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'LITE' format for exported model.
file_format (str): MindSpore currently supports 'GEIR', 'ONNX' 'LITE' and 'BINARY' format for exported model.
- GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
Ascend model.
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
- LITE: Huawei model format for mobile. A lite model only for the MindSpore Lite
- BINARY: Binary format for model. An intermidiate representation format for models.
"""
logger.info("exporting model file:%s format:%s.", file_name, file_format)
check_input_data(*inputs, data_class=Tensor)
supported_formats = ['GEIR', 'ONNX', 'LITE']
supported_formats = ['GEIR', 'ONNX', 'LITE', 'BINARY']
if file_format not in supported_formats:
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
# switch network mode to infer when it is training
@ -428,6 +429,13 @@ def export(net, *inputs, file_name, file_format='GEIR'):
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream)
elif file_format == 'BINARY': # file_format is 'BINARY'
phase_name = 'export_binary'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(graph_id, 'binary_ir')
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream)
elif file_format == 'LITE': # file_format is 'LITE'
context.set_context(save_ms_model=True, save_ms_model_path=file_name)
net(*inputs)

View File

@ -17,8 +17,9 @@
namespace mindspore {
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { return ""; }
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { return ""; }
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
} // namespace mindspore