fix bug of invalid data type or format when set node info

This commit is contained in:
looop5 2021-04-26 19:10:50 +08:00
parent 81cd26bdc8
commit 887d400e3f
2 changed files with 120 additions and 59 deletions

View File

@ -18,7 +18,6 @@
#include <memory>
#include <vector>
#include <string>
#include <algorithm>
#include <unordered_set>
#include "base/core_ops.h"
#include "utils/utils.h"
@ -89,7 +88,7 @@ bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &c
return true;
}
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const TypeId &node_type) {
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const NodeIOInfo &node_io_info) {
MS_EXCEPTION_IF_NULL(orig_node);
MS_EXCEPTION_IF_NULL(new_node);
@ -100,32 +99,19 @@ void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const Type
}
AbstractBasePtr new_abstract{nullptr};
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_device_type;
std::vector<TypeId> outputs_device_type{node_type};
KernelType kernel_type{AnfAlgo::GetKernelType(orig_node)};
kernel::OpPattern op_pattern{AnfAlgo::GetOpPattern(orig_node)};
kernel::FusionType fusion_type{AnfAlgo::GetFusionType(orig_node)};
kernel::Processor processor{AnfAlgo::GetProcessor(orig_node)};
auto node_data_inputs_num = AnfAlgo::GetInputNum(new_node);
for (size_t i = 0; i < node_data_inputs_num; ++i) {
auto node_input = AnfAlgo::GetInputNode(new_node, i);
auto node_input_format = AnfAlgo::GetOutputFormat(node_input, 0);
auto node_input_type = AnfAlgo::GetOutputDeviceDataType(node_input, 0);
inputs_format.push_back(node_input_format);
inputs_device_type.push_back(node_input_type);
if (node_io_info.outputs_type.empty()) {
MS_LOG(EXCEPTION) << "Can not set empty output type of new node from " << orig_node->fullname_with_scope();
}
if (node_name == "Cast") {
auto node_input = AnfAlgo::GetInputNode(new_node, 0);
new_abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), node_input->abstract()->BuildShape());
outputs_format.push_back(AnfAlgo::GetOutputFormat(node_input, 0));
MS_EXCEPTION_IF_NULL(node_input);
MS_EXCEPTION_IF_NULL(node_input->abstract());
new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
node_input->abstract()->BuildShape());
} else {
new_abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), orig_node->abstract()->BuildShape());
outputs_format.push_back(AnfAlgo::GetOutputFormat(orig_node, 0));
MS_EXCEPTION_IF_NULL(orig_node->abstract());
new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
orig_node->abstract()->BuildShape());
}
// Set abstract info
@ -135,14 +121,14 @@ void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const Type
// Set kernel build info
new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
info_builder.SetInputsFormat(inputs_format);
info_builder.SetInputsDeviceType(inputs_device_type);
info_builder.SetOutputsFormat(outputs_format);
info_builder.SetOutputsDeviceType(outputs_device_type);
info_builder.SetKernelType(kernel_type);
info_builder.SetOpPattern(op_pattern);
info_builder.SetFusionType(fusion_type);
info_builder.SetProcessor(processor);
info_builder.SetInputsFormat(node_io_info.inputs_format);
info_builder.SetInputsDeviceType(node_io_info.inputs_type);
info_builder.SetOutputsFormat(node_io_info.outputs_format);
info_builder.SetOutputsDeviceType(node_io_info.outputs_type);
info_builder.SetKernelType(AnfAlgo::GetKernelType(orig_node));
info_builder.SetOpPattern(AnfAlgo::GetOpPattern(orig_node));
info_builder.SetFusionType(AnfAlgo::GetFusionType(orig_node));
info_builder.SetProcessor(AnfAlgo::GetProcessor(orig_node));
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
}
} // namespace
@ -156,16 +142,22 @@ void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::v
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
<< new_input_at_indexes.size();
}
if (!new_inputs->empty()) {
new_inputs->resize(0);
auto node_inputs_num = node->size();
if (node_inputs_num == 0) {
MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
}
// node's inputs at indexes change to new_input_at_indexes
if (!new_inputs->empty()) {
new_inputs->resize(0);
}
new_inputs->push_back(node->input(0));
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
auto node_inputs_num = node->size();
size_t idx = 0;
for (size_t i = 0; i < node_inputs_num; ++i) {
if (indexes_set.find(i) == indexes_set.end()) {
for (size_t i = 1; i < node_inputs_num; ++i) {
size_t data_idx = i - 1;
if (indexes_set.find(data_idx) == indexes_set.end()) {
new_inputs->push_back(node->input(i));
} else {
new_inputs->push_back(new_input_at_indexes[idx++]);
@ -173,13 +165,57 @@ void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::v
}
}
void ReorderOps::SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
const std::vector<AnfNodePtr> &input_at_indexes,
NodeIOInfo *new_inputs_info, bool from_input) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(new_inputs_info);
if (indexes.size() != input_at_indexes.size()) {
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
<< input_at_indexes.size();
}
auto node_inputs_num = node->size();
if (node_inputs_num == 0) {
MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
}
// node's inputs info at indexes change to input_at_indexes's input or output info
new_inputs_info->inputs_format.resize(0);
new_inputs_info->inputs_type.resize(0);
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
size_t idx = 0;
for (size_t data_idx = 0; data_idx < node_inputs_num - 1; ++data_idx) {
if (indexes_set.find(data_idx) == indexes_set.end()) {
new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(node, data_idx));
new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(node, data_idx));
} else {
if (from_input) {
new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(input_at_indexes[idx], 0));
new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(input_at_indexes[idx], 0));
} else {
new_inputs_info->inputs_format.push_back(AnfAlgo::GetOutputFormat(input_at_indexes[idx], 0));
new_inputs_info->inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(input_at_indexes[idx], 0));
}
idx++;
}
}
}
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
const CNodePtr &node) {
// Limitation: Current cast node is CAST_DOWN.
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN) {
// Limitation:
// Current cast node is CAST_DOWN.
// Cast node will not change the input format.
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN ||
AnfAlgo::GetInputFormat(node, 0) != AnfAlgo::GetOutputFormat(node, 0)) {
return false;
}
auto large_type = AnfAlgo::GetInputDeviceDataType(node, 0);
auto small_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);
auto node_input = AnfAlgo::GetInputNode(node, 0);
auto type_insens_node = node_input->cast<CNodePtr>();
// Limitation:
@ -190,11 +226,9 @@ bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph,
return false;
}
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
// Limitation: Type insensitive node's inputs have same data type.
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, cast_input_type)) {
// Limitation: Type insensitive node's inputs are the large type.
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, large_type)) {
return false;
}
@ -202,17 +236,23 @@ bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph,
for (const auto &index : op_input_indexes) {
auto new_cast_node =
func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
SetNodeInfo(node, new_cast_node, cast_out_type);
NodeIOInfo cast_io_info;
cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index));
cast_io_info.outputs_format = cast_io_info.inputs_format;
cast_io_info.inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(type_insens_node, index));
cast_io_info.outputs_type.push_back(small_type);
SetNodeInfo(node, new_cast_node, cast_io_info);
new_cast_nodes.push_back(new_cast_node);
}
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
[](const size_t &idx) { return idx + 1; });
std::vector<AnfNodePtr> type_insens_node_new_inputs;
SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
NodeIOInfo type_insens_io_info;
type_insens_io_info.outputs_format.push_back(pattern_output_format);
type_insens_io_info.outputs_type.push_back(small_type);
SetTypeInsensitiveNodeInputsInfo(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_io_info, false);
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
SetNodeInfo(type_insens_node, new_type_insens_node, cast_out_type);
SetNodeInfo(type_insens_node, new_type_insens_node, type_insens_io_info);
(void)mng->Replace(node, new_type_insens_node);
return true;
@ -227,14 +267,16 @@ bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, co
// Limitation:
// Certain inputs of type insensitive node are cast node.
// Cast nodes are CAST_UP.
// Cast nodes will not change the input format.
// All these cast nodes are only used by current type insensitive node.
std::vector<CNodePtr> cast_nodes;
std::vector<AnfNodePtr> cast_nodes;
std::vector<AnfNodePtr> cast_input_nodes;
auto op_input_indexes = GetOpDataInputIndexes(node);
for (const auto &index : op_input_indexes) {
auto node_input = AnfAlgo::GetInputNode(node, index);
auto cast_node = node_input->cast<CNodePtr>();
if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
AnfAlgo::GetInputFormat(node, 0) == AnfAlgo::GetOutputFormat(node, 0) &&
mng->node_users()[cast_node].size() == 1) {
cast_nodes.push_back(cast_node);
cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
@ -244,29 +286,37 @@ bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, co
return false;
}
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
auto small_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
auto large_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);
// Limitation: All these cast nodes cast same type to another type.
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&cast_input_type](const CNodePtr &cast_node) {
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == cast_input_type;
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&small_type](const AnfNodePtr &cast_node) {
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == small_type;
})) {
return false;
}
// Limitation: Type insensitive node's inputs have same data type.
if (!CheckInputTypeConsistent(node, op_input_indexes, cast_out_type)) {
if (!CheckInputTypeConsistent(node, op_input_indexes, large_type)) {
return false;
}
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
[](const size_t &idx) { return idx + 1; });
std::vector<AnfNodePtr> type_insens_node_new_inputs;
SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
SetNodeInfo(node, new_type_insens_node, cast_input_type);
NodeIOInfo type_insens_io_info;
type_insens_io_info.outputs_format.push_back(pattern_output_format);
type_insens_io_info.outputs_type.push_back(small_type);
SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true);
SetNodeInfo(node, new_type_insens_node, type_insens_io_info);
auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
SetNodeInfo(cast_nodes[0], new_cast_node, cast_out_type);
NodeIOInfo cast_io_info;
cast_io_info.inputs_format.push_back(pattern_output_format);
cast_io_info.outputs_format = cast_io_info.inputs_format;
cast_io_info.inputs_type.push_back(small_type);
cast_io_info.outputs_type.push_back(large_type);
SetNodeInfo(cast_nodes[0]->cast<CNodePtr>(), new_cast_node, cast_io_info);
(void)mng->Replace(node, new_cast_node);
return true;

View File

@ -19,10 +19,18 @@
#include <memory>
#include <vector>
#include <string>
#include "backend/optimizer/common/pass.h"
namespace mindspore {
namespace opt {
struct NodeIOInfo {
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_type;
std::vector<TypeId> outputs_type;
};
class ReorderOps : public Pass {
public:
ReorderOps() : Pass("reorder_ops") {}
@ -33,6 +41,9 @@ class ReorderOps : public Pass {
void SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
const std::vector<AnfNodePtr> &new_input_in_indexes,
std::vector<AnfNodePtr> *new_inputs);
void SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
const std::vector<AnfNodePtr> &input_at_indexes, NodeIOInfo *new_inputs_info,
bool from_input);
bool ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
const CNodePtr &node);
bool ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,