!30503 ONNX converter: GNMT support part two

Merge pull request !30503 from amalyshev/onnx-gnmt-part-two
i-robot 2022-03-11 01:26:21 +00:00 committed by Gitee
commit fab3dcf03d
1 changed file with 370 additions and 15 deletions


@@ -541,6 +541,236 @@ void ClipPointsComponent(const std::string &points, const std::string &clipped,
AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
}
namespace while_loop_export {
namespace {
const char CONTROL_PATTERN[] = "";
const char LOOP_BODY_PATTERN[] = "";
const char AFTER_LOOP_PATTERN[] = "";
const size_t LOOP_BODY_INPUT = 2;
const size_t AFTER_LOOP_INPUT = 3;
bool IsSubgraphNameCorrect(const FuncGraphPtr &func_graph, const std::string &part_pattern) {
auto name = func_graph->ToString();
return name.find("construct") != std::string::npos && name.find(part_pattern) != std::string::npos;
}
template <typename T>
const std::shared_ptr<T> GetNodeInput(const CNodePtr &node, size_t i) {
auto input = GetRealInput(node->input(i));
auto result = dyn_cast<T>(input);
if (result == nullptr) {
MS_LOG(EXCEPTION) << "Failed to get input " << i << " of node " << node->DebugString();
}
return result;
}
template <typename T>
const std::shared_ptr<T> GetNodeInputValue(const CNodePtr &node, size_t i) {
auto input = GetNodeInput<ValueNode>(node, i);
auto result = dyn_cast<T>(input->value());
if (result == nullptr) {
MS_LOG(EXCEPTION) << "Failed to get a value from input " << i << " of node " << node->DebugString();
}
return result;
}
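// Expected control subgraph structure (an assumption based on the checks below, not a
// guarantee of the IR):
//   return(Switch(cond, Partial(loop_body, ...), Partial(after_loop, ...))())
// FindLoopSwitchNode walks back from the return node to the Switch node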
CNodePtr FindLoopSwitchNode(const FuncGraphPtr &control_subgraph) {
if (!IsSubgraphNameCorrect(control_subgraph, CONTROL_PATTERN)) {
MS_LOG(EXCEPTION) << "Expected a loop control structure";
}
auto lazy_call_node = GetNodeInput<CNode>(control_subgraph->get_return(), kOneNum);
if (lazy_call_node->inputs().size() != kOneNum || !lazy_call_node->input(kZeroNum)->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Expected a lazy call node";
}
auto switch_node = GetNodeInput<CNode>(lazy_call_node, kZeroNum);
if (!switch_node->IsApply(prim::kPrimSwitch)) {
MS_LOG(EXCEPTION) << "Expected a switch node";
}
return switch_node;
}
FuncGraphPtr GetSubgraph(const CNodePtr &switch_node, size_t input_index, const std::string &name_pattern) {
auto input_node = GetNodeInput<CNode>(switch_node, input_index);
if (!input_node->IsApply(prim::kPrimPartial)) {
MS_LOG(EXCEPTION) << "Expected a partial node";
}
auto subgraph = GetNodeInputValue<FuncGraph>(input_node, kOneNum);
if (!IsSubgraphNameCorrect(subgraph, name_pattern)) {
MS_LOG(EXCEPTION) << "Expected a loop part: " << name_pattern;
}
return subgraph;
}
// The inputs of this node are the outputs of ONNX Loop
CNodePtr FindLoopRepeatNode(const FuncGraphPtr &loop_subgraph, const FuncGraphPtr &control_subgraph) {
auto repeat_node = GetNodeInput<CNode>(loop_subgraph->get_return(), kOneNum);
auto maybe_control_graph = GetNodeInputValue<FuncGraph>(repeat_node, kZeroNum);
MS_EXCEPTION_IF_CHECK_FAIL(maybe_control_graph == control_subgraph, "Loop matching failed");
return repeat_node;
}
struct LoopConditionInfo {
int64_t begin;
int64_t end;
int64_t step;
};
/*
NOTE: loop support is currently very limited, because proper condition export requires more graph surgery (copying
condition expression before and inside Loop subgraph)
The only while loop form supported currently is the one used in GNMT v2's Beam Search. Python example:
i = begin
while i < end:
# ...
i += step
To enable proper support for arbitrary while loop conditions, the condition calculation would have to be duplicated
inside the Loop subgraph. But exporting the same ops twice with different names is not currently supported.
*/
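// Given these restrictions, the loop is exported as an ONNX Loop with a precomputed trip count
// M = (end - begin) / step and a constant true condition (see ExportWhileLoop below)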
LoopConditionInfo TraceLoopConditionInfo(const CNodePtr &start_node, const CNodePtr &cond_node,
const FuncGraphPtr &control_subgraph, const CNodePtr &loop_repeat_node) {
MS_EXCEPTION_IF_CHECK_FAIL(cond_node->IsApply(prim::kPrimLess), "Expected Less node");
auto counter = GetNodeInput<Parameter>(cond_node, kOneNum);
auto end_tensor = GetNodeInputValue<tensor::Tensor>(cond_node, kTwoNum);
MS_EXCEPTION_IF_CHECK_FAIL(end_tensor->shape_c().empty(), "Expected a scalar tensor");
auto end = *reinterpret_cast<const int32_t *>(end_tensor->data_c());
const auto &subgraph_args = control_subgraph->parameters();
auto counter_input_pos = std::find(subgraph_args.begin(), subgraph_args.end(), counter) - subgraph_args.begin();
auto begin_tensor = GetNodeInputValue<tensor::Tensor>(start_node, 1 + counter_input_pos);
MS_EXCEPTION_IF_CHECK_FAIL(begin_tensor->shape_c().empty(), "Expected a scalar tensor");
auto begin = *reinterpret_cast<const int32_t *>(begin_tensor->data_c());
auto increment_node = GetNodeInput<CNode>(loop_repeat_node, 1 + counter_input_pos);
MS_EXCEPTION_IF_CHECK_FAIL(increment_node->IsApply(prim::kPrimAdd), "Expected Add node");
auto step_tensor = GetNodeInputValue<tensor::Tensor>(increment_node, kTwoNum);
MS_EXCEPTION_IF_CHECK_FAIL(step_tensor->shape_c().empty(), "Expected a scalar tensor");
auto step = *reinterpret_cast<const int32_t *>(step_tensor->data_c());
return LoopConditionInfo{begin, end, step};
}
// result[i] is the index of the control subgraph input that feeds the i-th loop subgraph input
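// Example (hypothetical values): result = {2, 0, 1} means loop subgraph input 0 is fed from
// control subgraph input 2, loop input 1 from control input 0, and loop input 2 from control input 1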
std::vector<size_t> TraceLoopToControlMap(const FuncGraphPtr &control_subgraph) {
std::vector<size_t> result;
auto switch_node = FindLoopSwitchNode(control_subgraph);
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
const auto &control_params = control_subgraph->parameters();
auto auxiliary_inputs_num = 2;
for (size_t i = auxiliary_inputs_num; i < loop_partial_node->inputs().size(); ++i) {
auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
auto control_param_pos =
std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
result.push_back(control_param_pos);
}
return result;
}
std::vector<size_t> TraceAfterToLoopMap(const FuncGraphPtr &control_subgraph) {
std::vector<size_t> result;
auto switch_node = FindLoopSwitchNode(control_subgraph);
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
const auto &loop_params = loop_partial_node->inputs();
auto auxiliary_inputs_num = 2;
for (size_t i = auxiliary_inputs_num; i < after_partial_node->inputs().size(); ++i) {
auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
result.push_back(after_param_pos - auxiliary_inputs_num);
}
return result;
}
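// Marks loop inputs that should not become ONNX Loop carried dependencies: weights
// (Parameters with default values, exported as initializers) and monad inputs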
std::vector<bool> TraceIgnoredLoopParams(const CNodePtr &start_node, const std::vector<size_t> &loop_to_control_map) {
auto inputs_num = start_node->inputs().size() - 1;
std::vector<bool> result(inputs_num);
for (size_t loop_i = 0; loop_i < inputs_num; ++loop_i) {
auto control_i = loop_to_control_map.at(loop_i);
const auto &input = start_node->input(control_i + 1);
if ((input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) || HasAbstractMonad(input)) {
result.at(loop_i) = true;
}
}
return result;
}
} // namespace
bool IsControlSubgraph(const ValuePtr &func_graph_node) {
auto func_graph = dyn_cast<FuncGraph>(func_graph_node);
return func_graph != nullptr && IsSubgraphNameCorrect(func_graph, CONTROL_PATTERN);
}
bool IsLoopBodyReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
return IsSubgraphNameCorrect(func_graph, LOOP_BODY_PATTERN) && node == func_graph->get_return();
}
bool IsAfterLoopReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
return IsSubgraphNameCorrect(func_graph, AFTER_LOOP_PATTERN) && node == func_graph->get_return();
}
struct LoopParts {
LoopConditionInfo loop_condition_info;
std::vector<std::pair<size_t, size_t>> after_param_to_output_indices;
std::vector<size_t> ignored_loop_param_indices;
std::vector<std::pair<size_t, size_t>> used_loop_to_control_param_indices;
CNodePtr repeat_node;
FuncGraphPtr loop_subgraph;
FuncGraphPtr after_loop_subgraph;
};
LoopParts MatchGraph(const CNodePtr &start_node) {
LoopParts result;
auto control_subgraph_value = dyn_cast<ValueNode>(start_node->input(0));
MS_EXCEPTION_IF_NULL(control_subgraph_value);
auto control_subgraph = dyn_cast<FuncGraph>(control_subgraph_value->value());
MS_EXCEPTION_IF_NULL(control_subgraph);
auto switch_node = FindLoopSwitchNode(control_subgraph);
auto cond_node = GetNodeInput<CNode>(switch_node, kOneNum);
result.loop_subgraph = GetSubgraph(switch_node, LOOP_BODY_INPUT, LOOP_BODY_PATTERN);
result.repeat_node = FindLoopRepeatNode(result.loop_subgraph, control_subgraph);
result.loop_condition_info = TraceLoopConditionInfo(start_node, cond_node, control_subgraph, result.repeat_node);
result.after_loop_subgraph = GetSubgraph(switch_node, AFTER_LOOP_INPUT, AFTER_LOOP_PATTERN);
auto loop_to_control_order_map = TraceLoopToControlMap(control_subgraph);
auto ignored_loop_params_mask = TraceIgnoredLoopParams(start_node, loop_to_control_order_map);
auto loop_inputs_num = start_node->inputs().size() - 1;
for (size_t i = 0; i < loop_inputs_num; ++i) {
if (ignored_loop_params_mask.at(i)) {
result.ignored_loop_param_indices.push_back(i);
} else {
result.used_loop_to_control_param_indices.push_back(std::make_pair(i, loop_to_control_order_map.at(i)));
}
}
auto after_to_loop_order_map = TraceAfterToLoopMap(control_subgraph);
for (size_t after_i = 0; after_i < result.after_loop_subgraph->parameters().size(); ++after_i) {
auto loop_i = after_to_loop_order_map.at(after_i);
if (!ignored_loop_params_mask.at(loop_i)) {
auto output_i = loop_i;
for (size_t i = 0; i < loop_i; ++i) {
output_i -= ignored_loop_params_mask.at(i);
}
result.after_param_to_output_indices.push_back(std::make_pair(after_i, output_i));
}
}
return result;
}
} // namespace while_loop_export
class OpAttrInfo {
public:
OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name,
@@ -888,6 +1118,8 @@ class OnnxExporter {
void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
@@ -1056,6 +1288,13 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, std::map<AnfN
// export computational nodes and output nodes
ExportNodes(func_graph, node_map_ptr, graph_proto);
// add names for easier debugging
for (auto &node : *graph_proto->mutable_node()) {
if (!node.has_name()) {
node.set_name(node.output(0) + node.op_type());
}
}
MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString();
}
@@ -1179,6 +1418,10 @@ struct MergeRule {
void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
auto &op_merged_infos = *op_merged_infos_ptr;
const auto ignore = [&op_merged_infos](const AnfNodePtr &node) {
op_merged_infos[node].mode = OP_MERGE_IGNORE;
op_merged_infos[node].referred_count -= 1;
};
const std::vector<MergeRule> first_input_merge_rules = {
{prim::kPrimBiasAdd, prim::kPrimConv2D, OP_MERGE_CONV},
@@ -1199,20 +1442,30 @@ void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNode
MS_LOG(EXCEPTION) << "Multiple outputs for node \"" << cnode->input(1)->ToString() << "\" are not supported";
}
op_merged_infos[cnode].mode = rule->merge_mode;
ignore(cnode->input(1));
} else if (while_loop_export::IsLoopBodyReturnNode(cnode, func_graph)) {
// Ignored here; the loop outputs are exported explicitly in ExportWhileLoop
ignore(cnode);
auto repeat_node = dyn_cast<CNode>(GetRealInput(cnode->input(1)));
MS_EXCEPTION_IF_NULL(repeat_node);
ignore(repeat_node);
} else if (while_loop_export::IsAfterLoopReturnNode(cnode, func_graph)) {
// Ignored so that the after-loop subgraph is inlined into the main graph
ignore(cnode);
auto first_input = GetRealInput(cnode->input(1));
if (IsPrimitiveCNode(first_input, prim::kPrimMakeTuple)) {
ignore(first_input);
}
} else if (cnode == func_graph->get_return()) {
auto first_input = GetRealInput(cnode->input(1)); // Unpack Depend
if (IsPrimitiveCNode(first_input, prim::kPrimMakeTuple)) {
// Ignore MakeTuple output node to avoid exporting it to SequenceConstruct
// and handle multiple outputs in ExportOutput
ignore(first_input);
}
} else if (cnode->IsApply(prim::kPrimConcat) && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
// Ignore MakeTuple to handle it in ExportPrimConcat
ignore(cnode->input(1));
}
}
@@ -2888,16 +3141,115 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
op_inputs.push_back(inputs[i]);
}
}
if (!op->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name();
}
auto op_value = dyn_cast<ValueNode>(op)->value();
if (op_value->isa<Primitive>()) {
auto prim = dyn_cast<Primitive>(op_value);
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
} else if (while_loop_export::IsControlSubgraph(op_value)) {
ExportWhileLoop(node, node_map_ptr, graph_proto);
} else {
MS_LOG(EXCEPTION) << "Need to support node op value type " << op_value->type_name();
}
}
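// Exports a matched while loop as roughly the following ONNX structure (a sketch with
// hypothetical tensor names, not the exact output):
//   loop_out_0, ... = Loop(<node>_M, <node>_cond, carried_0, ...)  // attribute body = loop subgraph
//   <inlined after-loop nodes reading loop_out_*>
//   <node>_0, ... = Identity(<after-loop outputs>)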
void OnnxExporter::ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *graph_proto) {
auto node_name = RegisterNodeWithUniqueName(start_node, node_map_ptr);
auto loop_parts = while_loop_export::MatchGraph(start_node);
// 1. Make Loop op
onnx::NodeProto *loop_proto = graph_proto->add_node();
loop_proto->set_op_type("Loop");
auto loop_count_name = node_name + "_M";
const auto &loop_counter_params = loop_parts.loop_condition_info;
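// ONNX Loop takes a precomputed trip count; this count matches the original `while i < end`
// semantics exactly only when (end - begin) is a multiple of step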
int64_t loop_count = (loop_counter_params.end - loop_counter_params.begin) / loop_counter_params.step;
onnx::TensorProto *loop_count_proto = graph_proto->add_initializer();
loop_count_proto->set_name(loop_count_name);
loop_count_proto->set_data_type(onnx::TensorProto_DataType_INT64);
loop_count_proto->add_int64_data(loop_count);
auto loop_cond_name = node_name + "_cond";
auto *cond_value = graph_proto->add_initializer();
cond_value->set_name(loop_cond_name);
cond_value->set_data_type(onnx::TensorProto_DataType_BOOL);
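// ONNX TensorProto has no dedicated bool field, so BOOL tensor data is stored in int32_data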
cond_value->add_int32_data(true);
loop_proto->add_input(loop_count_name);
loop_proto->add_input(loop_cond_name);
for (const auto &[loop_i, control_i] : loop_parts.used_loop_to_control_param_indices) {
auto name = GetNodeInputName(start_node->input(control_i + 1), node_map_ptr, graph_proto);
loop_proto->add_input(name);
loop_proto->add_output(MakeOutputName(node_name + "_loop", loop_i));
}
onnx::AttributeProto *subgraph_attr = loop_proto->add_attribute();
subgraph_attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
subgraph_attr->set_name("body");
onnx::GraphProto *loop_subgraph_proto = subgraph_attr->mutable_g();
// 2. Create subgraph for loop body
auto subgraph_name = loop_parts.loop_subgraph->ToString();
auto subgraph_input_cond_name = subgraph_name + "_input_cond";
auto *iter_num_input = loop_subgraph_proto->add_input();
iter_num_input->set_name(subgraph_name + "_input_M");
(void)iter_num_input->mutable_type()->mutable_tensor_type()->mutable_shape(); // side-effect: shape created
iter_num_input->mutable_type()->mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_INT64);
auto *cond_input = loop_subgraph_proto->add_input();
cond_input->set_name(subgraph_input_cond_name);
cond_input->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());
auto *cond_output = loop_subgraph_proto->add_output();
cond_output->set_name(cond_input->name());
cond_output->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());
MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
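// An empty name marks the parameter as ignored: GetNodeInputName skips such entries and
// falls back to the regular lookup, so weights become initializers in the root graph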
for (size_t i : loop_parts.ignored_loop_param_indices) {
const auto &param = loop_parts.loop_subgraph->parameters().at(i);
renamed_node_map_[param] = "";
}
// Export everything except the control call and the output (see MatchAndMark)
ExportFuncGraph(loop_parts.loop_subgraph, node_map_ptr, loop_subgraph_proto);
// Export outputs manually
for (const auto &loop_to_control_i : loop_parts.used_loop_to_control_param_indices) {
const auto &input = loop_parts.repeat_node->input(loop_to_control_i.second + 1);
ExportOutput(loop_parts.loop_subgraph, input, node_map_ptr, loop_subgraph_proto);
}
renamed_node_map_.clear();
// 3. Export part after loop
MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
const auto &after_loop_params = loop_parts.after_loop_subgraph->parameters();
for (const auto &[after_i, output_i] : loop_parts.after_param_to_output_indices) {
MS_EXCEPTION_IF_CHECK_FAIL(static_cast<int>(output_i) < loop_proto->output_size(), "Output index out of bounds");
renamed_node_map_[after_loop_params.at(after_i)] = loop_proto->output(output_i);
}
ExportFuncGraph(loop_parts.after_loop_subgraph, node_map_ptr, graph_proto, false);
auto after_loop_retval = GetRealInput(loop_parts.after_loop_subgraph->get_return()->input(1));
if (after_loop_retval->isa<CNode>() && after_loop_retval->cast<CNodePtr>()->IsApply(prim::kPrimMakeTuple)) {
auto tuple_retval = dyn_cast<CNode>(after_loop_retval);
for (size_t i = 1; i < tuple_retval->inputs().size(); ++i) {
auto output_name = GetNodeInputName(tuple_retval->input(i), node_map_ptr, graph_proto);
AddOp("Identity", {output_name}, {MakeOutputName(node_name, i - 1)}, graph_proto);
}
} else {
auto output_name = GetNodeInputName(after_loop_retval, node_map_ptr, graph_proto);
AddOp("Identity", {output_name}, {node_name}, graph_proto);
}
renamed_node_map_.clear();
}
onnx::TensorProto_DataType OnnxExporter::GetOutputType(const AnfNodePtr &node, int64_t output_index) {
@@ -3223,8 +3575,10 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map
onnx::GraphProto *const graph_proto) {
auto node = GetRealInput(orig_node);
// if node is renamed and not ignored, use alternative name
// if it is ignored, try to find the actual name in global map
auto renamed_iter = renamed_node_map_.find(node);
if (renamed_iter != renamed_node_map_.end() && renamed_iter->second != "") {
return renamed_iter->second;
}
@@ -3238,11 +3592,12 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map
}
// for ValueNode or Parameter with default input, create an initializer
// same value can be used in several subgraphs, so create initializers in root graph
if (node->isa<ValueNode>()) {
auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
auto value = node->cast<ValueNodePtr>()->value();
onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
initializer_proto->set_name(node_name);
SetTensorData(value, initializer_proto);
@@ -3254,7 +3609,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map
auto param = dyn_cast<Parameter>(node);
auto node_name = GenerateUniqueParameterName(param, node_map_ptr);
onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
initializer_proto->set_name(node_name);
SetTensorData(param->default_param(), initializer_proto);