回退 'Pull Request !41904 : 昇腾众智-西安电子科技大学-ONNX converter:stpm'
This commit is contained in:
parent
c4613af9b8
commit
946853f2f2
|
@ -1237,8 +1237,6 @@ class OnnxExporter {
|
|||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimLSTM(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
void ExportPrimMirrorPad(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimReverseV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimTensorCopySlices(const FuncGraphPtr &, const CNodePtr &node,
|
||||
|
@ -3266,115 +3264,6 @@ void OnnxExporter::ExportPrimLSTM(const FuncGraphPtr &, const CNodePtr &node,
|
|||
AddReshapeOp(transpose_node_name, output_name, {seq_len, batch_size, num_dir * hidden_size}, graph_proto);
|
||||
}
|
||||
|
||||
// MirrorPad -> Pad + Compress
|
||||
void OnnxExporter::ExportPrimMirrorPad(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
|
||||
auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
|
||||
|
||||
const auto &input_x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
|
||||
const auto &paddings_shape = dyn_cast<abstract::Shape>(node->input(kTwoNum)->Shape())->shape();
|
||||
|
||||
size_t dim0 = paddings_shape[0];
|
||||
auto p_node = dyn_cast<ValueNode>(node->input(kTwoNum));
|
||||
auto paddings_values = *reinterpret_cast<size_t(*)[dim0][2]>(dyn_cast<tensor::Tensor>(p_node->value())->data_c());
|
||||
|
||||
auto mode = GetOpAttribute<std::string>(node, "mode");
|
||||
|
||||
std::vector<int64_t> pads_sequence, pads_symmetric;
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
pads_sequence.push_back(paddings_values[i][0]);
|
||||
}
|
||||
for (size_t j = 0; j < dim0; ++j) {
|
||||
pads_sequence.push_back(paddings_values[j][1]);
|
||||
}
|
||||
auto pads_ptr = MakeValue<std::vector<int64_t>>(pads_sequence);
|
||||
|
||||
auto pads_name = RegisterNodeWithUniqueName(NewValueNode(pads_ptr)->cast<AnfNodePtr>(), node_map_ptr);
|
||||
onnx::NodeProto *pads_node = graph_proto->add_node();
|
||||
pads_node->add_output(pads_name);
|
||||
pads_node->set_op_type("Constant");
|
||||
onnx::AttributeProto *pads_attr_proto = pads_node->add_attribute();
|
||||
pads_attr_proto->set_name("value");
|
||||
pads_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(pads_ptr, pads_attr_proto->mutable_t());
|
||||
|
||||
auto ms_pad_node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
onnx::NodeProto *onnx_pad_node = graph_proto->add_node();
|
||||
onnx_pad_node->set_op_type("Pad");
|
||||
onnx_pad_node->add_output(ms_pad_node_name);
|
||||
if (mode == "SYMMETRIC") {
|
||||
for (size_t i = 0; i < (dim0 + dim0); ++i) {
|
||||
pads_symmetric.push_back(1);
|
||||
}
|
||||
auto p_s_ptr = MakeValue<std::vector<int64_t>>(pads_symmetric);
|
||||
|
||||
auto pads_symmetric_name = RegisterNodeWithUniqueName(NewValueNode(p_s_ptr)->cast<AnfNodePtr>(), node_map_ptr);
|
||||
onnx::NodeProto *pads_symmetric_node = graph_proto->add_node();
|
||||
pads_symmetric_node->add_output(pads_symmetric_name);
|
||||
pads_symmetric_node->set_op_type("Constant");
|
||||
onnx::AttributeProto *pads_symmetric_attr_proto = pads_symmetric_node->add_attribute();
|
||||
pads_symmetric_attr_proto->set_name("value");
|
||||
pads_symmetric_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(p_s_ptr, pads_symmetric_attr_proto->mutable_t());
|
||||
|
||||
auto ms_pad_symmetric_node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
onnx::NodeProto *onnx_pad_symmetric_node = graph_proto->add_node();
|
||||
onnx_pad_symmetric_node->set_op_type("Pad");
|
||||
onnx_pad_symmetric_node->add_output(ms_pad_symmetric_node_name);
|
||||
onnx_pad_symmetric_node->add_input(input_x_name);
|
||||
onnx_pad_symmetric_node->add_input(pads_symmetric_name);
|
||||
|
||||
onnx_pad_node->add_input(ms_pad_symmetric_node_name);
|
||||
} else {
|
||||
onnx_pad_node->add_input(input_x_name);
|
||||
}
|
||||
onnx_pad_node->add_input(pads_name);
|
||||
|
||||
onnx::AttributeProto *mode_proto = onnx_pad_node->add_attribute();
|
||||
mode_proto->set_name("mode");
|
||||
mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
mode_proto->set_s("reflect");
|
||||
|
||||
if (mode == "SYMMETRIC") {
|
||||
auto input_symmetric_name = ms_pad_node_name;
|
||||
for (size_t k = 0; k < dim0; ++k) {
|
||||
std::vector<int64_t> condition_sequence;
|
||||
int64_t symmetric_len = pads_sequence[k] + pads_sequence[k + dim0] + input_x_shape[k] + 2;
|
||||
int64_t second_len = input_x_shape[k] + 1 + pads_sequence[k];
|
||||
for (int64_t l = 0; l < symmetric_len; ++l) {
|
||||
if (l == pads_sequence[k] || l == second_len) {
|
||||
condition_sequence.push_back(0);
|
||||
} else {
|
||||
condition_sequence.push_back(1);
|
||||
}
|
||||
}
|
||||
auto c_ptr = MakeValue<std::vector<int64_t>>(condition_sequence);
|
||||
|
||||
auto condition_sequence_n = RegisterNodeWithUniqueName(NewValueNode(c_ptr)->cast<AnfNodePtr>(), node_map_ptr);
|
||||
onnx::NodeProto *condition_sequence_node = graph_proto->add_node();
|
||||
condition_sequence_node->add_output(condition_sequence_n);
|
||||
condition_sequence_node->set_op_type("Constant");
|
||||
onnx::AttributeProto *condition_sequence_attr_proto = condition_sequence_node->add_attribute();
|
||||
condition_sequence_attr_proto->set_name("value");
|
||||
condition_sequence_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(c_ptr, condition_sequence_attr_proto->mutable_t());
|
||||
|
||||
auto compress_node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
onnx::NodeProto *compress_node = graph_proto->add_node();
|
||||
compress_node->set_op_type("Compress");
|
||||
compress_node->add_output(compress_node_name);
|
||||
compress_node->add_input(input_symmetric_name);
|
||||
compress_node->add_input(condition_sequence_n);
|
||||
onnx::AttributeProto *axis_attr_proto = compress_node->add_attribute();
|
||||
axis_attr_proto->set_name("axis");
|
||||
axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||
axis_attr_proto->set_i(k);
|
||||
|
||||
input_symmetric_name = compress_node_name;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimReverseV2(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
|
@ -3634,7 +3523,6 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
{prim::kPrimAddN, &OnnxExporter::ExportPrimAddN},
|
||||
{prim::kPrimGeLU, &OnnxExporter::ExportPrimGeLU},
|
||||
{prim::kPrimLstm, &OnnxExporter::ExportPrimLSTM},
|
||||
{prim::kPrimMirrorPad, &OnnxExporter::ExportPrimMirrorPad},
|
||||
{prim::kPrimReverseV2, &OnnxExporter::ExportPrimReverseV2},
|
||||
{prim::kPrimTensorCopySlices, &OnnxExporter::ExportPrimTensorCopySlices},
|
||||
{prim::kPrimStack, &OnnxExporter::ExportPrimStack},
|
||||
|
@ -3794,12 +3682,6 @@ onnx::TensorProto_DataType OnnxExporter::GetOutputType(const AnfNodePtr &node, i
|
|||
*/
|
||||
|
||||
if (output_index == -1) {
|
||||
if (IsPrimitiveCNode(unpacked, prim::kPrimMakeTuple)) {
|
||||
auto tuple = dyn_cast<Tuple>(unpacked->Type());
|
||||
auto element_type = tuple->elements()[0];
|
||||
auto tensor_type = dyn_cast<TensorType>(element_type);
|
||||
return GetOnnxDataType(tensor_type->element()->type_id());
|
||||
}
|
||||
auto tensor = dyn_cast<TensorType>(unpacked->Type());
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Expected output of node " << unpacked->ToString()
|
||||
|
@ -4061,12 +3943,6 @@ void OnnxExporter::ExportOutput(const FuncGraphPtr &, const AnfNodePtr &return_a
|
|||
auto output_name = GetNodeInputName(output, node_map_ptr, graph_proto);
|
||||
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
output_proto->set_name(output_name);
|
||||
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
auto dtype = GetOutputType(output);
|
||||
auto *type_proto = output_proto->mutable_type();
|
||||
type_proto->mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type()->set_elem_type(dtype);
|
||||
continue;
|
||||
}
|
||||
SetValueInfoType(output, output_proto);
|
||||
}
|
||||
} else if (arg->isa<ValueNode>() && arg->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
|
||||
|
|
Loading…
Reference in New Issue