forked from mindspore-Ecosystem/mindspore
!30503 ONNX converter: GNMT support part two
Merge pull request !30503 from amalyshev/onnx-gnmt-part-two
This commit is contained in:
commit
fab3dcf03d
|
@ -541,6 +541,236 @@ void ClipPointsComponent(const std::string &points, const std::string &clipped,
|
|||
AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
|
||||
}
|
||||
|
||||
namespace while_loop_export {
|
||||
namespace {
|
||||
const char CONTROL_PATTERN[] = "⤾";
|
||||
const char LOOP_BODY_PATTERN[] = "⥁";
|
||||
const char AFTER_LOOP_PATTERN[] = "↓";
|
||||
|
||||
const size_t LOOP_BODY_INPUT = 2;
|
||||
const size_t AFTER_LOOP_INPUT = 3;
|
||||
|
||||
bool IsSubgraphNameCorrect(const FuncGraphPtr &func_graph, const std::string &part_pattern) {
|
||||
auto name = func_graph->ToString();
|
||||
return name.find("construct") != std::string::npos && name.find(part_pattern) != std::string::npos;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const std::shared_ptr<T> GetNodeInput(const CNodePtr &node, size_t i) {
|
||||
auto input = GetRealInput(node->input(i));
|
||||
auto result = dyn_cast<T>(input);
|
||||
if (result == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get input " << i << " of node " << node->DebugString();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const std::shared_ptr<T> GetNodeInputValue(const CNodePtr &node, size_t i) {
|
||||
auto input = GetNodeInput<ValueNode>(node, i);
|
||||
auto result = dyn_cast<T>(input->value());
|
||||
if (result == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get a value from input " << i << " of node " << node->DebugString();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CNodePtr FindLoopSwitchNode(const FuncGraphPtr &control_subgraph) {
|
||||
if (!IsSubgraphNameCorrect(control_subgraph, CONTROL_PATTERN)) {
|
||||
MS_LOG(EXCEPTION) << "Expected a loop control structure";
|
||||
}
|
||||
auto lazy_call_node = GetNodeInput<CNode>(control_subgraph->get_return(), kOneNum);
|
||||
if (lazy_call_node->inputs().size() != kOneNum || !lazy_call_node->input(kZeroNum)->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Expected a lazy call node";
|
||||
}
|
||||
auto switch_node = GetNodeInput<CNode>(lazy_call_node, kZeroNum);
|
||||
if (!switch_node->IsApply(prim::kPrimSwitch)) {
|
||||
MS_LOG(EXCEPTION) << "Expected a switch node";
|
||||
}
|
||||
return switch_node;
|
||||
}
|
||||
|
||||
FuncGraphPtr GetSubgraph(const CNodePtr &switch_node, size_t input_index, const std::string &name_pattern) {
|
||||
auto input_node = GetNodeInput<CNode>(switch_node, input_index);
|
||||
if (!input_node->IsApply(prim::kPrimPartial)) {
|
||||
MS_LOG(EXCEPTION) << "Expected a partial node";
|
||||
}
|
||||
|
||||
auto subgraph = GetNodeInputValue<FuncGraph>(input_node, kOneNum);
|
||||
if (!IsSubgraphNameCorrect(subgraph, name_pattern)) {
|
||||
MS_LOG(EXCEPTION) << "Expected a loop part: " << name_pattern;
|
||||
}
|
||||
|
||||
return subgraph;
|
||||
}
|
||||
|
||||
// The inputs of this node are the outputs of ONNX Loop
|
||||
CNodePtr FindLoopRepeatNode(const FuncGraphPtr &loop_subgraph, const FuncGraphPtr &control_subgraph) {
|
||||
auto repeat_node = GetNodeInput<CNode>(loop_subgraph->return_node(), kOneNum);
|
||||
auto maybe_control_graph = GetNodeInputValue<FuncGraph>(repeat_node, kZeroNum);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(maybe_control_graph == control_subgraph, "Loop matching failed");
|
||||
return repeat_node;
|
||||
}
|
||||
|
||||
struct LoopConditionInfo {
|
||||
int64_t begin;
|
||||
int64_t end;
|
||||
int64_t step;
|
||||
};
|
||||
|
||||
/*
|
||||
NOTE: loop support is currently very limited, because proper condition export requires more graph surgery (copying
|
||||
condition expression before and inside Loop subgraph)
|
||||
The only while loop form supported currently is the one used in GNMT v2's Beam Search. Python example:
|
||||
i = begin
|
||||
while i < end:
|
||||
# ...
|
||||
i += step
|
||||
To enable proper support for arbitrary while loop contitions, condition calculation should be duplicated inside the
|
||||
Loop supgraph. But exporting the same ops twice with different names is not currently supported.
|
||||
*/
|
||||
LoopConditionInfo TraceLoopConditionInfo(const CNodePtr &start_node, const CNodePtr &cond_node,
|
||||
const FuncGraphPtr &control_subgraph, const CNodePtr &loop_repeat_node) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(cond_node->IsApply(prim::kPrimLess), "Expected Less node");
|
||||
|
||||
auto counter = GetNodeInput<Parameter>(cond_node, kOneNum);
|
||||
auto end_tensor = GetNodeInputValue<tensor::Tensor>(cond_node, kTwoNum);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(end_tensor->shape_c().empty(), "Expected a scalar tensor");
|
||||
auto end = *reinterpret_cast<const int32_t *>(end_tensor->data_c());
|
||||
|
||||
const auto &subgraph_args = control_subgraph->parameters();
|
||||
auto counter_input_pos = std::find(subgraph_args.begin(), subgraph_args.end(), counter) - subgraph_args.begin();
|
||||
|
||||
auto begin_tensor = GetNodeInputValue<tensor::Tensor>(start_node, 1 + counter_input_pos);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(begin_tensor->shape_c().empty(), "Expected a scalar tensor");
|
||||
auto begin = *reinterpret_cast<const int32_t *>(begin_tensor->data_c());
|
||||
|
||||
auto increment_node = GetNodeInput<CNode>(loop_repeat_node, 1 + counter_input_pos);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(increment_node->IsApply(prim::kPrimAdd), "Expected Add node");
|
||||
auto step_tensor = GetNodeInputValue<tensor::Tensor>(increment_node, kTwoNum);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(step_tensor->shape_c().empty(), "Expected a scalar tensor");
|
||||
auto step = *reinterpret_cast<const int32_t *>(step_tensor->data_c());
|
||||
|
||||
return LoopConditionInfo{begin, end, step};
|
||||
}
|
||||
|
||||
// result[i] is which control subgraph input should be taken for pos i to match the order of loop subgraph inputs
|
||||
std::vector<size_t> TraceLoopToControlMap(const FuncGraphPtr &control_subgraph) {
|
||||
std::vector<size_t> result;
|
||||
|
||||
auto switch_node = FindLoopSwitchNode(control_subgraph);
|
||||
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
|
||||
const auto &control_params = control_subgraph->parameters();
|
||||
auto auxiliary_inputs_num = 2;
|
||||
for (size_t i = auxiliary_inputs_num; i < loop_partial_node->inputs().size(); ++i) {
|
||||
auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
|
||||
auto control_param_pos =
|
||||
std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
|
||||
result.push_back(control_param_pos);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<size_t> TraceAfterToLoopMap(const FuncGraphPtr &control_subgraph) {
|
||||
std::vector<size_t> result;
|
||||
|
||||
auto switch_node = FindLoopSwitchNode(control_subgraph);
|
||||
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
|
||||
auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
|
||||
const auto &loop_params = loop_partial_node->inputs();
|
||||
auto auxiliary_inputs_num = 2;
|
||||
for (size_t i = auxiliary_inputs_num; i < after_partial_node->inputs().size(); ++i) {
|
||||
auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
|
||||
auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
|
||||
result.push_back(after_param_pos - auxiliary_inputs_num);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<bool> TraceIgnoredLoopParams(const CNodePtr &start_node, const std::vector<size_t> &loop_to_control_map) {
|
||||
auto inputs_num = start_node->inputs().size() - 1;
|
||||
std::vector<bool> result(inputs_num);
|
||||
for (size_t loop_i = 0; loop_i < inputs_num; ++loop_i) {
|
||||
auto control_i = loop_to_control_map.at(loop_i);
|
||||
const auto &input = start_node->input(control_i + 1);
|
||||
if ((input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) || HasAbstractMonad(input)) {
|
||||
result.at(loop_i) = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool IsControlSubgraph(const ValuePtr &func_graph_node) {
|
||||
auto func_graph = dyn_cast<FuncGraph>(func_graph_node);
|
||||
return func_graph != nullptr && IsSubgraphNameCorrect(func_graph, CONTROL_PATTERN);
|
||||
}
|
||||
|
||||
bool IsLoopBodyReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
|
||||
return IsSubgraphNameCorrect(func_graph, LOOP_BODY_PATTERN) && node == func_graph->get_return();
|
||||
}
|
||||
|
||||
bool IsAfterLoopReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
|
||||
return IsSubgraphNameCorrect(func_graph, AFTER_LOOP_PATTERN) && node == func_graph->get_return();
|
||||
}
|
||||
|
||||
struct LoopParts {
|
||||
LoopConditionInfo loop_condition_info;
|
||||
std::vector<std::pair<size_t, size_t>> after_param_to_output_indices;
|
||||
std::vector<size_t> ignored_loop_param_indices;
|
||||
std::vector<std::pair<size_t, size_t>> used_loop_to_control_param_indices;
|
||||
CNodePtr repeat_node;
|
||||
FuncGraphPtr loop_subgraph;
|
||||
FuncGraphPtr after_loop_subgraph;
|
||||
};
|
||||
|
||||
LoopParts MatchGraph(const CNodePtr &start_node) {
|
||||
LoopParts result;
|
||||
|
||||
auto control_subgraph_value = dyn_cast<ValueNode>(start_node->input(0));
|
||||
MS_EXCEPTION_IF_NULL(control_subgraph_value);
|
||||
auto control_subgraph = dyn_cast<FuncGraph>(control_subgraph_value->value());
|
||||
MS_EXCEPTION_IF_NULL(control_subgraph);
|
||||
|
||||
auto switch_node = FindLoopSwitchNode(control_subgraph);
|
||||
auto cond_node = GetNodeInput<CNode>(switch_node, kOneNum);
|
||||
|
||||
result.loop_subgraph = GetSubgraph(switch_node, LOOP_BODY_INPUT, LOOP_BODY_PATTERN);
|
||||
|
||||
result.repeat_node = FindLoopRepeatNode(result.loop_subgraph, control_subgraph);
|
||||
result.loop_condition_info = TraceLoopConditionInfo(start_node, cond_node, control_subgraph, result.repeat_node);
|
||||
|
||||
result.after_loop_subgraph = GetSubgraph(switch_node, AFTER_LOOP_INPUT, AFTER_LOOP_PATTERN);
|
||||
|
||||
auto loop_to_control_order_map = TraceLoopToControlMap(control_subgraph);
|
||||
auto ignored_loop_params_mask = TraceIgnoredLoopParams(start_node, loop_to_control_order_map);
|
||||
auto loop_inputs_num = start_node->inputs().size() - 1;
|
||||
for (size_t i = 0; i < loop_inputs_num; ++i) {
|
||||
if (ignored_loop_params_mask.at(i)) {
|
||||
result.ignored_loop_param_indices.push_back(i);
|
||||
} else {
|
||||
result.used_loop_to_control_param_indices.push_back(std::make_pair(i, loop_to_control_order_map.at(i)));
|
||||
}
|
||||
}
|
||||
|
||||
auto after_to_loop_order_map = TraceAfterToLoopMap(control_subgraph);
|
||||
for (size_t after_i = 0; after_i < result.after_loop_subgraph->parameters().size(); ++after_i) {
|
||||
auto loop_i = after_to_loop_order_map.at(after_i);
|
||||
if (!ignored_loop_params_mask.at(loop_i)) {
|
||||
auto output_i = loop_i;
|
||||
for (size_t i = 0; i < loop_i; ++i) {
|
||||
output_i -= ignored_loop_params_mask.at(i);
|
||||
}
|
||||
result.after_param_to_output_indices.push_back(std::make_pair(after_i, output_i));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace while_loop_export
|
||||
|
||||
class OpAttrInfo {
|
||||
public:
|
||||
OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name,
|
||||
|
@ -888,6 +1118,8 @@ class OnnxExporter {
|
|||
|
||||
void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
|
||||
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
|
@ -1056,6 +1288,13 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, std::map<AnfN
|
|||
// export computational nodes and output nodes
|
||||
ExportNodes(func_graph, node_map_ptr, graph_proto);
|
||||
|
||||
// add names for easier debugging
|
||||
for (auto &node : *graph_proto->mutable_node()) {
|
||||
if (!node.has_name()) {
|
||||
node.set_name(node.output(0) + node.op_type());
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString();
|
||||
}
|
||||
|
||||
|
@ -1179,6 +1418,10 @@ struct MergeRule {
|
|||
void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
|
||||
auto &op_merged_infos = *op_merged_infos_ptr;
|
||||
const auto ignore = [&op_merged_infos](const AnfNodePtr &node) {
|
||||
op_merged_infos[node].mode = OP_MERGE_IGNORE;
|
||||
op_merged_infos[node].referred_count -= 1;
|
||||
};
|
||||
|
||||
const std::vector<MergeRule> first_input_merge_rules = {
|
||||
{prim::kPrimBiasAdd, prim::kPrimConv2D, OP_MERGE_CONV},
|
||||
|
@ -1199,20 +1442,30 @@ void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNode
|
|||
MS_LOG(EXCEPTION) << "Multiple outputs for node \"" << cnode->input(1)->ToString() << "\" are not supported";
|
||||
}
|
||||
op_merged_infos[cnode].mode = rule->merge_mode;
|
||||
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
||||
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
||||
ignore(cnode->input(1));
|
||||
} else if (while_loop_export::IsLoopBodyReturnNode(cnode, func_graph)) {
|
||||
// Ignore to replace with other outputs
|
||||
ignore(cnode);
|
||||
auto repeat_node = dyn_cast<CNode>(GetRealInput(cnode->input(1)));
|
||||
MS_EXCEPTION_IF_NULL(repeat_node);
|
||||
ignore(repeat_node);
|
||||
} else if (while_loop_export::IsAfterLoopReturnNode(cnode, func_graph)) {
|
||||
// Ignore to inline after-loop subgraph in main graph
|
||||
ignore(cnode);
|
||||
auto first_input = GetRealInput(cnode->input(1));
|
||||
if (IsPrimitiveCNode(first_input, prim::kPrimMakeTuple)) {
|
||||
ignore(first_input);
|
||||
}
|
||||
} else if (cnode == func_graph->get_return()) {
|
||||
auto first_input = GetRealInput(cnode->input(1)); // Unpack Depend
|
||||
if (IsPrimitiveCNode(first_input, prim::kPrimMakeTuple)) {
|
||||
// Ignore MakeTuple output node to avoid exporting it to SequenceConstruct
|
||||
// and handle multiple outputs in ExportOutput
|
||||
op_merged_infos[first_input].mode = OP_MERGE_IGNORE;
|
||||
op_merged_infos[first_input].referred_count -= 1;
|
||||
ignore(first_input);
|
||||
}
|
||||
} else if (cnode->IsApply(prim::kPrimConcat) && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
|
||||
// Ignore MakeTuple to handle it in ExportPrimConcat
|
||||
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
||||
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
||||
ignore(cnode->input(1));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2888,16 +3141,115 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
op_inputs.push_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
auto op_value = dyn_cast<ValueNode>(op);
|
||||
if (op_value == nullptr) {
|
||||
|
||||
if (!op->isa<ValueNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name();
|
||||
}
|
||||
auto prim = dyn_cast<Primitive>(op_value->value());
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name();
|
||||
|
||||
auto op_value = dyn_cast<ValueNode>(op)->value();
|
||||
if (op_value->isa<Primitive>()) {
|
||||
auto prim = dyn_cast<Primitive>(op_value);
|
||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
|
||||
} else if (while_loop_export::IsControlSubgraph(op_value)) {
|
||||
ExportWhileLoop(node, node_map_ptr, graph_proto);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Need to support node op value type " << op_value->type_name();
|
||||
}
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto) {
|
||||
auto node_name = RegisterNodeWithUniqueName(start_node, node_map_ptr);
|
||||
auto loop_parts = while_loop_export::MatchGraph(start_node);
|
||||
|
||||
// 1. Make Loop op
|
||||
|
||||
onnx::NodeProto *loop_proto = graph_proto->add_node();
|
||||
loop_proto->set_op_type("Loop");
|
||||
|
||||
auto loop_count_name = node_name + "_M";
|
||||
const auto &loop_counter_params = loop_parts.loop_condition_info;
|
||||
int64_t loop_count = (loop_counter_params.end - loop_counter_params.begin) / loop_counter_params.step;
|
||||
onnx::TensorProto *loop_count_proto = graph_proto->add_initializer();
|
||||
loop_count_proto->set_name(loop_count_name);
|
||||
loop_count_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
||||
loop_count_proto->add_int64_data(loop_count);
|
||||
|
||||
auto loop_cond_name = node_name + "_cond";
|
||||
auto *cond_value = graph_proto->add_initializer();
|
||||
cond_value->set_name(loop_cond_name);
|
||||
cond_value->set_data_type(onnx::TensorProto_DataType_BOOL);
|
||||
cond_value->add_int32_data(true);
|
||||
|
||||
loop_proto->add_input(loop_count_name);
|
||||
loop_proto->add_input(loop_cond_name);
|
||||
for (const auto &[loop_i, control_i] : loop_parts.used_loop_to_control_param_indices) {
|
||||
auto name = GetNodeInputName(start_node->input(control_i + 1), node_map_ptr, graph_proto);
|
||||
loop_proto->add_input(name);
|
||||
loop_proto->add_output(MakeOutputName(node_name + "_loop", loop_i));
|
||||
}
|
||||
|
||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
|
||||
onnx::AttributeProto *subgraph_attr = loop_proto->add_attribute();
|
||||
subgraph_attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
|
||||
subgraph_attr->set_name("body");
|
||||
onnx::GraphProto *loop_subgraph_proto = subgraph_attr->mutable_g();
|
||||
|
||||
// 2. Create subgraph for loop body
|
||||
|
||||
auto subgraph_name = loop_parts.loop_subgraph->ToString();
|
||||
auto subgraph_input_cond_name = subgraph_name + "_input_cond";
|
||||
|
||||
auto *iter_num_input = loop_subgraph_proto->add_input();
|
||||
iter_num_input->set_name(subgraph_name + "_input_M");
|
||||
(void)iter_num_input->mutable_type()->mutable_tensor_type()->mutable_shape(); // side-effect: shape created
|
||||
iter_num_input->mutable_type()->mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_INT64);
|
||||
|
||||
auto *cond_input = loop_subgraph_proto->add_input();
|
||||
cond_input->set_name(subgraph_input_cond_name);
|
||||
cond_input->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());
|
||||
|
||||
auto *cond_output = loop_subgraph_proto->add_output();
|
||||
cond_output->set_name(cond_input->name());
|
||||
cond_output->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());
|
||||
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
|
||||
for (size_t i : loop_parts.ignored_loop_param_indices) {
|
||||
const auto ¶m = loop_parts.loop_subgraph->parameters().at(i);
|
||||
renamed_node_map_[param] = "";
|
||||
}
|
||||
|
||||
// Export everything except the control call and the output (see MatchAndMark)
|
||||
ExportFuncGraph(loop_parts.loop_subgraph, node_map_ptr, loop_subgraph_proto);
|
||||
|
||||
// Export outputs manually
|
||||
for (const auto &loop_to_control_i : loop_parts.used_loop_to_control_param_indices) {
|
||||
const auto &input = loop_parts.repeat_node->input(loop_to_control_i.second + 1);
|
||||
ExportOutput(loop_parts.loop_subgraph, input, node_map_ptr, loop_subgraph_proto);
|
||||
}
|
||||
renamed_node_map_.clear();
|
||||
|
||||
// 3. Export part after loop
|
||||
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
|
||||
const auto &after_loop_params = loop_parts.after_loop_subgraph->parameters();
|
||||
for (const auto &[after_i, output_i] : loop_parts.after_param_to_output_indices) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(static_cast<int>(output_i) < loop_proto->output_size(), "Output index out of bounds");
|
||||
renamed_node_map_[after_loop_params.at(after_i)] = loop_proto->output(output_i);
|
||||
}
|
||||
ExportFuncGraph(loop_parts.after_loop_subgraph, node_map_ptr, graph_proto, false);
|
||||
|
||||
auto after_loop_retval = GetRealInput(loop_parts.after_loop_subgraph->get_return()->input(1));
|
||||
if (after_loop_retval->isa<CNode>() && after_loop_retval->cast<CNodePtr>()->IsApply(prim::kPrimMakeTuple)) {
|
||||
auto tuple_retval = dyn_cast<CNode>(after_loop_retval);
|
||||
for (size_t i = 1; i < tuple_retval->inputs().size(); ++i) {
|
||||
auto output_name = GetNodeInputName(tuple_retval->input(i), node_map_ptr, graph_proto);
|
||||
AddOp("Identity", {output_name}, {MakeOutputName(node_name, i - 1)}, graph_proto);
|
||||
}
|
||||
} else {
|
||||
auto output_name = GetNodeInputName(after_loop_retval, node_map_ptr, graph_proto);
|
||||
AddOp("Identity", {output_name}, {node_name}, graph_proto);
|
||||
}
|
||||
renamed_node_map_.clear();
|
||||
}
|
||||
|
||||
onnx::TensorProto_DataType OnnxExporter::GetOutputType(const AnfNodePtr &node, int64_t output_index) {
|
||||
|
@ -3223,8 +3575,10 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map
|
|||
onnx::GraphProto *const graph_proto) {
|
||||
auto node = GetRealInput(orig_node);
|
||||
|
||||
// if node is renamed and not ignored, use alternative name
|
||||
// if it is ignored, try to find the actual name in global map
|
||||
auto renamed_iter = renamed_node_map_.find(node);
|
||||
if (renamed_iter != renamed_node_map_.end()) {
|
||||
if (renamed_iter != renamed_node_map_.end() && renamed_iter->second != "") {
|
||||
return renamed_iter->second;
|
||||
}
|
||||
|
||||
|
@ -3238,11 +3592,12 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map
|
|||
}
|
||||
|
||||
// for ValueNode or Parameter with default input, create an initializer
|
||||
// same value can be used in several subgraphs, so create initializers in root graph
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
auto value = node->cast<ValueNodePtr>()->value();
|
||||
|
||||
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
|
||||
onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
|
||||
initializer_proto->set_name(node_name);
|
||||
SetTensorData(value, initializer_proto);
|
||||
|
||||
|
@ -3254,7 +3609,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map
|
|||
auto param = dyn_cast<Parameter>(node);
|
||||
auto node_name = GenerateUniqueParameterName(param, node_map_ptr);
|
||||
|
||||
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
|
||||
onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
|
||||
initializer_proto->set_name(node_name);
|
||||
SetTensorData(param->default_param(), initializer_proto);
|
||||
|
||||
|
|
Loading…
Reference in New Issue