|
|
|
@ -0,0 +1,631 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
|
*
|
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
*
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include <fstream>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <functional>
|
|
|
|
|
|
|
|
|
|
#include "ir/param_value_py.h"
|
|
|
|
|
#include "debug/anf_ir_utils.h"
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
#include "proto/onnx.pb.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
|
|
|
|
|
using FloatPtr = std::shared_ptr<Float>;
|
|
|
|
|
using IntPtr = std::shared_ptr<Int>;
|
|
|
|
|
|
|
|
|
|
// anf type to onnx type map
|
|
|
|
|
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_type_map = {
|
|
|
|
|
{kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
|
|
|
|
|
{kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
|
|
|
|
|
{kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
|
|
|
|
|
{kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
|
|
|
|
|
{kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
|
|
|
|
|
{kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
|
|
|
|
|
{kObjectTypeString, onnx::TensorProto_DataType_STRING},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_int_map = {
|
|
|
|
|
{8, onnx::TensorProto_DataType_INT8},
|
|
|
|
|
{16, onnx::TensorProto_DataType_INT16},
|
|
|
|
|
{32, onnx::TensorProto_DataType_INT32},
|
|
|
|
|
{64, onnx::TensorProto_DataType_INT64},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_float_map = {
|
|
|
|
|
{16, onnx::TensorProto_DataType_FLOAT16},
|
|
|
|
|
{32, onnx::TensorProto_DataType_FLOAT},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Can build different builder according to format
|
|
|
|
|
class IrExportBuilder;
|
|
|
|
|
using IrExportBuilderPtr = std::shared_ptr<IrExportBuilder>;
|
|
|
|
|
|
|
|
|
|
class IrExporter {
|
|
|
|
|
public:
|
|
|
|
|
explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
|
|
|
|
|
virtual ~IrExporter() = default;
|
|
|
|
|
std::string GetDumpString(const FuncGraphPtr &func_graph);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
IrExportBuilderPtr builder_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class IrExportBuilder {
|
|
|
|
|
public:
|
|
|
|
|
IrExportBuilder() = default;
|
|
|
|
|
~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
|
|
|
|
|
std::string GetProtoString(const FuncGraphPtr &func_graph);
|
|
|
|
|
void BuildModelInfo();
|
|
|
|
|
void BuildModel(const FuncGraphPtr &func_graph);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
|
|
|
|
|
void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
|
|
|
|
|
void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
|
|
|
|
|
void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto);
|
|
|
|
|
void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto);
|
|
|
|
|
std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto);
|
|
|
|
|
|
|
|
|
|
void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto);
|
|
|
|
|
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto);
|
|
|
|
|
void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto);
|
|
|
|
|
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
|
|
|
|
|
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
|
|
|
|
|
onnx::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto);
|
|
|
|
|
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
|
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
|
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
|
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
|
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
|
|
|
|
|
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
|
|
|
|
|
|
onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
|
|
|
|
|
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
|
|
|
|
|
onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits);
|
|
|
|
|
std::string GetNodeName(const AnfNodePtr &node);
|
|
|
|
|
std::string GetUniqueNodeName(const AnfNodePtr &node);
|
|
|
|
|
std::string GetOpTypeName(const AnfNodePtr &node);
|
|
|
|
|
size_t AllocateIndex() { return ++node_index_; }
|
|
|
|
|
void ResetIndex() { node_index_ = 0; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
onnx::ModelProto model_;
|
|
|
|
|
onnx::NodeProto *last_node_;
|
|
|
|
|
std::list<FuncGraphPtr> todo_;
|
|
|
|
|
std::map<AnfNodePtr, size_t> node_index_map_;
|
|
|
|
|
size_t node_index_ = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
|
|
|
|
|
|
|
|
|
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
|
|
|
|
|
if ((builder_ == nullptr) || (func_graph == nullptr)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input params is null.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Export model info
|
|
|
|
|
builder_->BuildModelInfo();
|
|
|
|
|
|
|
|
|
|
// Export model and return string
|
|
|
|
|
builder_->BuildModel(func_graph);
|
|
|
|
|
|
|
|
|
|
return builder_->GetProtoString(func_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_LOG(DEBUG) << "BuildModel complete!";
|
|
|
|
|
return model_.SerializeAsString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildModelInfo() {
|
|
|
|
|
model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
|
|
|
|
|
model_.set_producer_name("MindSpore");
|
|
|
|
|
model_.set_model_version(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
|
|
|
|
|
onnx::GraphProto *graph_proto = model_.mutable_graph();
|
|
|
|
|
graph_proto->set_name(func_graph->ToString());
|
|
|
|
|
ResetIndex();
|
|
|
|
|
todo_.clear();
|
|
|
|
|
todo_.push_back(func_graph);
|
|
|
|
|
while (!todo_.empty()) {
|
|
|
|
|
FuncGraphPtr fg = todo_.back();
|
|
|
|
|
todo_.pop_back();
|
|
|
|
|
BuildFuncGraph(fg, graph_proto);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
|
|
|
|
// Export parameters
|
|
|
|
|
// 1. parameters should be mapped to ValueInfoProto
|
|
|
|
|
// 2. parameters with default value should be mapped to Initializer
|
|
|
|
|
BuildParameters(func_graph, graph_proto);
|
|
|
|
|
|
|
|
|
|
// Export operator nodes(include output)
|
|
|
|
|
BuildNodes(func_graph, graph_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
|
|
|
|
for (auto &item : func_graph->parameters()) {
|
|
|
|
|
auto param = item->cast<ParameterPtr>();
|
|
|
|
|
if (param == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
|
|
|
|
}
|
|
|
|
|
onnx::ValueInfoProto *input_proto = graph_proto->add_input();
|
|
|
|
|
std::string param_name = GetUniqueNodeName(param);
|
|
|
|
|
input_proto->set_name(param_name);
|
|
|
|
|
SetValueInfoProto(param, input_proto);
|
|
|
|
|
if (!param->has_default()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Using ONNX initializer to set parameter's default value
|
|
|
|
|
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
|
|
|
|
|
initializer_proto->set_name(param_name);
|
|
|
|
|
SetParamToTensorProto(param, initializer_proto);
|
|
|
|
|
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
|
|
|
|
|
py::object obj = param_value->value();
|
|
|
|
|
py::object data = obj.attr("data");
|
|
|
|
|
if (py::isinstance<tensor::Tensor>(data)) {
|
|
|
|
|
auto method = data.attr("asnumpy");
|
|
|
|
|
py::array npy_data = method();
|
|
|
|
|
initializer_proto->set_raw_data(npy_data.request(true).ptr, static_cast<size_t>(npy_data.nbytes()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) {
|
|
|
|
|
auto iter = g_data_type_map.find(type_id);
|
|
|
|
|
if (iter == g_data_type_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
|
|
|
|
|
}
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) {
|
|
|
|
|
auto iter = g_data_bits_int_map.find(bits);
|
|
|
|
|
if (iter == g_data_bits_int_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
|
|
|
|
|
}
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) {
|
|
|
|
|
auto iter = g_data_bits_float_map.find(bits);
|
|
|
|
|
if (iter == g_data_bits_float_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
|
|
|
|
|
}
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) {
|
|
|
|
|
if (node == nullptr || value_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
|
|
|
|
|
SetValueInfoProto(node->Type(), node->Shape(), value_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
onnx::ValueInfoProto *const value_proto) {
|
|
|
|
|
onnx::TypeProto *type_proto = value_proto->mutable_type();
|
|
|
|
|
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
|
|
|
|
auto tensor = type->cast<TensorTypePtr>();
|
|
|
|
|
auto elem_type = tensor->element();
|
|
|
|
|
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
|
|
|
|
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
|
|
|
|
|
for (const auto &dim : dims) {
|
|
|
|
|
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
|
|
|
|
|
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
|
|
|
|
|
}
|
|
|
|
|
} else if (type->isa<Tuple>()) {
|
|
|
|
|
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
|
|
|
|
type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
|
attr_proto->set_ref_attr_name("tensor");
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
|
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
|
|
|
|
auto data = value->cast<tensor::TensorPtr>();
|
|
|
|
|
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes()));
|
|
|
|
|
auto dtype = data->data_type();
|
|
|
|
|
auto shape = data->shape_c();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataType(dtype));
|
|
|
|
|
for (const auto &dim : shape) {
|
|
|
|
|
tensor_proto->add_dims(dim);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
onnx::TensorProto *const tensor_proto) {
|
|
|
|
|
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
|
|
|
|
|
}
|
|
|
|
|
auto tensor = type->cast<TensorTypePtr>();
|
|
|
|
|
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id()));
|
|
|
|
|
for (const auto &dim : dims) {
|
|
|
|
|
tensor_proto->add_dims(dim);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) {
|
|
|
|
|
if (param == nullptr || tensor_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
|
|
|
|
|
SetTensorProto(param->Type(), param->Shape(), tensor_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
|
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
|
|
|
|
for (const AnfNodePtr &node : nodes) {
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode == func_graph->get_return()) {
|
|
|
|
|
BuildOutput(cnode, graph_proto);
|
|
|
|
|
} else {
|
|
|
|
|
BuildCNode(cnode, graph_proto);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
|
|
|
|
|
if (node->size() != 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr arg = node->input(1);
|
|
|
|
|
// Using make_tuple to set multi-output
|
|
|
|
|
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
|
|
|
|
|
auto tuple_node = arg->cast<CNodePtr>();
|
|
|
|
|
for (size_t i = 1; i < tuple_node->size(); i++) {
|
|
|
|
|
auto input_node = arg->cast<CNodePtr>()->input(i);
|
|
|
|
|
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
|
|
|
|
auto output_name = GetUniqueNodeName(tuple_node->input(i));
|
|
|
|
|
output_proto->set_name(output_name);
|
|
|
|
|
last_node_->add_output(output_name);
|
|
|
|
|
SetValueInfoProto(tuple_node->input(i), output_proto);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
|
|
|
|
std::string output_name = GetUniqueNodeName(node);
|
|
|
|
|
output_proto->set_name(output_name);
|
|
|
|
|
last_node_->add_output(output_name);
|
|
|
|
|
SetValueInfoProto(arg, output_proto);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|
|
|
|
// May be ValueNode/CNode/Parameter
|
|
|
|
|
std::string type_name = "";
|
|
|
|
|
if (IsValueNode<Primitive>(node)) {
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
|
|
|
|
|
type_name = prim->ToString();
|
|
|
|
|
} else if (IsValueNode<FuncGraph>(node)) {
|
|
|
|
|
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
|
|
|
|
todo_.push_back(fg);
|
|
|
|
|
type_name = fg->ToString();
|
|
|
|
|
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
|
|
|
|
|
type_name = node->ToString();
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "ExportType: " << type_name;
|
|
|
|
|
return type_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
onnx::NodeProto *const node_proto) {
|
|
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
|
|
|
|
attr_proto->set_ref_attr_name("shape");
|
|
|
|
|
attr_proto->set_name("shape");
|
|
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
|
|
|
|
SetTensorProto(type, shape, tensor_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
|
|
|
|
|
onnx::NodeProto *const node_proto) {
|
|
|
|
|
// Get shape of cnode
|
|
|
|
|
// 1. prim kPrimTupleGetItem need to get shape of input node according to the index
|
|
|
|
|
// 2. some cnode doesn't has shape, such as LayerNorm
|
|
|
|
|
// 3. other cnodes have shape
|
|
|
|
|
if (node->IsApply(prim::kPrimTupleGetItem)) {
|
|
|
|
|
// Get index of tuple get_item
|
|
|
|
|
int index_pos = inputs.size() - 1;
|
|
|
|
|
if (!inputs[index_pos]->isa<ValueNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos;
|
|
|
|
|
}
|
|
|
|
|
auto value = inputs[index_pos]->cast<ValueNodePtr>()->value();
|
|
|
|
|
if (!value->isa<IntergerImm>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name();
|
|
|
|
|
}
|
|
|
|
|
size_t index = GetValue<int>(value);
|
|
|
|
|
|
|
|
|
|
// Get type and shape of input node
|
|
|
|
|
auto tup_type = inputs[0]->Type();
|
|
|
|
|
if (!tup_type->isa<Tuple>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name();
|
|
|
|
|
}
|
|
|
|
|
auto type = tup_type->cast<TuplePtr>()->elements()[index];
|
|
|
|
|
auto tup_shape = inputs[0]->Shape()->cast<abstract::TupleShapePtr>();
|
|
|
|
|
if (index >= tup_shape->shape().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size();
|
|
|
|
|
}
|
|
|
|
|
auto shape = tup_shape->shape()[index];
|
|
|
|
|
SetShapeToNodeProto(type, shape, node_proto);
|
|
|
|
|
} else {
|
|
|
|
|
auto type = node->Type();
|
|
|
|
|
auto shape = node->Shape();
|
|
|
|
|
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
SetShapeToNodeProto(type, shape, node_proto);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
|
|
|
|
|
auto inputs_size = node->size();
|
|
|
|
|
if (inputs_size < 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Need to build input node before dealing with cnode
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs;
|
|
|
|
|
std::vector<string> input_names;
|
|
|
|
|
for (size_t i = 1; i < inputs_size; i++) {
|
|
|
|
|
auto input = node->input(i);
|
|
|
|
|
op_inputs.push_back(input);
|
|
|
|
|
input_names.push_back(BuildInputNode(input, graph_proto));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Build cnode
|
|
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
|
|
|
|
std::string output_name = GetUniqueNodeName(node);
|
|
|
|
|
node_proto->add_output(output_name);
|
|
|
|
|
node_proto->set_name(output_name);
|
|
|
|
|
AnfNodePtr op = node->input(0);
|
|
|
|
|
std::string type_name = GetOpTypeName(op);
|
|
|
|
|
node_proto->set_op_type(type_name);
|
|
|
|
|
last_node_ = node_proto;
|
|
|
|
|
SetShapeToNodeProto(node, op_inputs, node_proto);
|
|
|
|
|
(void)std::for_each(input_names.begin(), input_names.end(),
|
|
|
|
|
[&node_proto](const string &name) { node_proto->add_input(name); });
|
|
|
|
|
|
|
|
|
|
// Add primitive attrs
|
|
|
|
|
if (IsValueNode<Primitive>(op)) {
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(op);
|
|
|
|
|
for (auto attr : prim->attrs()) {
|
|
|
|
|
MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
|
|
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
|
|
|
|
attr_proto->set_name(attr.first);
|
|
|
|
|
SetValueToAttributeProto(attr.second, attr_proto);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) {
|
|
|
|
|
std::string node_name = GetUniqueNodeName(node);
|
|
|
|
|
if (node->isa<ValueNode>()) {
|
|
|
|
|
// When node input is a ValueNode, need to create a Constant Node
|
|
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
|
|
|
|
node_proto->add_output(node_name);
|
|
|
|
|
SetAttributeProto(node, node_proto);
|
|
|
|
|
}
|
|
|
|
|
return node_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
|
|
|
|
|
// Naming anfnode
|
|
|
|
|
// 1. parameter is unique in one func_graph
|
|
|
|
|
// 2. cnode and valuenode may be reduplicative, so add index to identify.
|
|
|
|
|
std::string node_name = "";
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
node_name = GetNodeName(node);
|
|
|
|
|
} else if (node->isa<CNode>() || node->isa<ValueNode>()) {
|
|
|
|
|
auto iter = node_index_map_.find(node);
|
|
|
|
|
if (iter != node_index_map_.end()) {
|
|
|
|
|
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
|
|
|
|
|
} else {
|
|
|
|
|
auto node_idx = AllocateIndex();
|
|
|
|
|
node_index_map_[node] = node_idx;
|
|
|
|
|
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Node name: " << node_name;
|
|
|
|
|
return node_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
|
|
|
|
std::string node_name = "";
|
|
|
|
|
if ((node != nullptr) && (node->func_graph() != nullptr)) {
|
|
|
|
|
node_name = node->func_graph()->ToString() + ":";
|
|
|
|
|
}
|
|
|
|
|
node_name += node->ToString();
|
|
|
|
|
MS_LOG(DEBUG) << "GetNodeName: " << node_name;
|
|
|
|
|
return node_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) {
|
|
|
|
|
if (node == nullptr || node_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
|
|
|
|
|
}
|
|
|
|
|
auto value = node->cast<ValueNodePtr>()->value();
|
|
|
|
|
node_proto->set_op_type("Constant");
|
|
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
|
|
|
|
attr_proto->set_name("value");
|
|
|
|
|
MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
|
|
|
|
|
SetValueToAttributeProto(value, attr_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
|
attr_proto->set_ref_attr_name("type");
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
|
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
|
|
|
|
if (value->isa<Int>()) {
|
|
|
|
|
auto int_value = value->cast<IntPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
|
|
|
|
|
} else if (value->isa<Float>()) {
|
|
|
|
|
auto float_value = value->cast<FloatPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
|
|
|
|
|
} else if (value->isa<TensorType>()) {
|
|
|
|
|
tensor_proto->set_name("tensor");
|
|
|
|
|
auto elem_type = value->cast<TensorTypePtr>()->element();
|
|
|
|
|
if (elem_type->isa<Int>()) {
|
|
|
|
|
auto int_value = elem_type->cast<IntPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
|
|
|
|
|
} else if (elem_type->isa<Float>()) {
|
|
|
|
|
auto float_value = elem_type->cast<FloatPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
|
if (value->isa<StringImm>() || value->isa<Scalar>()) {
|
|
|
|
|
SetScalarToAttributeProto(value, attr_proto);
|
|
|
|
|
} else if (value->isa<Number>() || value->isa<TensorType>()) {
|
|
|
|
|
SetTypeToAttributeProto(value, attr_proto);
|
|
|
|
|
} else if (value->isa<ValueSequeue>()) {
|
|
|
|
|
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
|
|
|
|
|
} else if (value->isa<tensor::Tensor>()) {
|
|
|
|
|
SetTensorToAttributeProto(value, attr_proto);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
|
attr_proto->set_ref_attr_name("scalar");
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
|
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
|
|
|
|
SetScalarToProto(value, tensor_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
|
|
|
|
|
if (value == nullptr || tensor_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
|
|
|
|
|
}
|
|
|
|
|
if (value->isa<StringImm>()) {
|
|
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
|
|
|
|
|
tensor_proto->add_string_data(GetValue<std::string>(value));
|
|
|
|
|
} else if (value->isa<BoolImm>()) {
|
|
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
|
|
|
|
|
tensor_proto->add_int32_data(GetValue<bool>(value));
|
|
|
|
|
} else if (value->isa<Int8Imm>()) {
|
|
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8);
|
|
|
|
|
tensor_proto->add_int32_data(value->cast<Int8ImmPtr>()->value());
|
|
|
|
|
} else if (value->isa<Int16Imm>()) {
|
|
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16);
|
|
|
|
|
tensor_proto->add_int32_data(value->cast<Int16ImmPtr>()->value());
|
|
|
|
|
} else if (value->isa<Int32Imm>()) {
|
|
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32);
|
|
|
|
|
tensor_proto->add_int32_data(value->cast<Int32ImmPtr>()->value());
|
|
|
|
|
} else if (value->isa<Int64Imm>()) {
|
|
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
|
|
|
|
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
|
|
|
|
|
} else if (value->isa<FloatImm>()) {
|
|
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
|
|
|
|
|
tensor_proto->add_float_data(GetValue<float>(value));
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|
|
|
|
onnx::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
|
attr_proto->set_ref_attr_name("scalar");
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
|
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
|
|
|
|
if (value->isa<ValueTuple>()) {
|
|
|
|
|
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
|
|
|
|
|
if (tuple_value->value().size() == 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto type_id = tuple_value->value()[0]->type()->type_id();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataType(type_id));
|
|
|
|
|
for (const auto &item : tuple_value->value()) {
|
|
|
|
|
SetScalarToProto(item, tensor_proto);
|
|
|
|
|
}
|
|
|
|
|
} else if (value->isa<ValueList>()) {
|
|
|
|
|
const ValueListPtr &list_value = value->cast<ValueListPtr>();
|
|
|
|
|
if (list_value->value().size() == 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto type_id = list_value->value()[0]->type()->type_id();
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataType(type_id));
|
|
|
|
|
for (const auto &item : list_value->value()) {
|
|
|
|
|
SetScalarToProto(item, tensor_proto);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
|
|
|
|
|
auto builder = std::make_shared<IrExportBuilder>();
|
|
|
|
|
if (builder == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Create ir exporter failed!";
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
auto exporter = std::make_shared<IrExporter>(builder);
|
|
|
|
|
if (exporter == nullptr) {
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
return exporter->GetDumpString(func_graph);
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore
|