forked from mindspore-Ecosystem/mindspore
fix bug of invalid data type or format when set node info
This commit is contained in:
parent
81cd26bdc8
commit
887d400e3f
|
@ -18,7 +18,6 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
|
@ -89,7 +88,7 @@ bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &c
|
|||
return true;
|
||||
}
|
||||
|
||||
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const TypeId &node_type) {
|
||||
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const NodeIOInfo &node_io_info) {
|
||||
MS_EXCEPTION_IF_NULL(orig_node);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
|
||||
|
@ -100,32 +99,19 @@ void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const Type
|
|||
}
|
||||
|
||||
AbstractBasePtr new_abstract{nullptr};
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> inputs_device_type;
|
||||
std::vector<TypeId> outputs_device_type{node_type};
|
||||
KernelType kernel_type{AnfAlgo::GetKernelType(orig_node)};
|
||||
kernel::OpPattern op_pattern{AnfAlgo::GetOpPattern(orig_node)};
|
||||
kernel::FusionType fusion_type{AnfAlgo::GetFusionType(orig_node)};
|
||||
kernel::Processor processor{AnfAlgo::GetProcessor(orig_node)};
|
||||
|
||||
auto node_data_inputs_num = AnfAlgo::GetInputNum(new_node);
|
||||
for (size_t i = 0; i < node_data_inputs_num; ++i) {
|
||||
auto node_input = AnfAlgo::GetInputNode(new_node, i);
|
||||
auto node_input_format = AnfAlgo::GetOutputFormat(node_input, 0);
|
||||
auto node_input_type = AnfAlgo::GetOutputDeviceDataType(node_input, 0);
|
||||
inputs_format.push_back(node_input_format);
|
||||
inputs_device_type.push_back(node_input_type);
|
||||
if (node_io_info.outputs_type.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not set empty output type of new node from " << orig_node->fullname_with_scope();
|
||||
}
|
||||
if (node_name == "Cast") {
|
||||
auto node_input = AnfAlgo::GetInputNode(new_node, 0);
|
||||
new_abstract =
|
||||
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), node_input->abstract()->BuildShape());
|
||||
outputs_format.push_back(AnfAlgo::GetOutputFormat(node_input, 0));
|
||||
MS_EXCEPTION_IF_NULL(node_input);
|
||||
MS_EXCEPTION_IF_NULL(node_input->abstract());
|
||||
new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
|
||||
node_input->abstract()->BuildShape());
|
||||
} else {
|
||||
new_abstract =
|
||||
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), orig_node->abstract()->BuildShape());
|
||||
outputs_format.push_back(AnfAlgo::GetOutputFormat(orig_node, 0));
|
||||
MS_EXCEPTION_IF_NULL(orig_node->abstract());
|
||||
new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
|
||||
orig_node->abstract()->BuildShape());
|
||||
}
|
||||
|
||||
// Set abstract info
|
||||
|
@ -135,14 +121,14 @@ void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const Type
|
|||
// Set kernel build info
|
||||
new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
|
||||
info_builder.SetInputsFormat(inputs_format);
|
||||
info_builder.SetInputsDeviceType(inputs_device_type);
|
||||
info_builder.SetOutputsFormat(outputs_format);
|
||||
info_builder.SetOutputsDeviceType(outputs_device_type);
|
||||
info_builder.SetKernelType(kernel_type);
|
||||
info_builder.SetOpPattern(op_pattern);
|
||||
info_builder.SetFusionType(fusion_type);
|
||||
info_builder.SetProcessor(processor);
|
||||
info_builder.SetInputsFormat(node_io_info.inputs_format);
|
||||
info_builder.SetInputsDeviceType(node_io_info.inputs_type);
|
||||
info_builder.SetOutputsFormat(node_io_info.outputs_format);
|
||||
info_builder.SetOutputsDeviceType(node_io_info.outputs_type);
|
||||
info_builder.SetKernelType(AnfAlgo::GetKernelType(orig_node));
|
||||
info_builder.SetOpPattern(AnfAlgo::GetOpPattern(orig_node));
|
||||
info_builder.SetFusionType(AnfAlgo::GetFusionType(orig_node));
|
||||
info_builder.SetProcessor(AnfAlgo::GetProcessor(orig_node));
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
|
||||
}
|
||||
} // namespace
|
||||
|
@ -156,16 +142,22 @@ void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::v
|
|||
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
|
||||
<< new_input_at_indexes.size();
|
||||
}
|
||||
if (!new_inputs->empty()) {
|
||||
new_inputs->resize(0);
|
||||
|
||||
auto node_inputs_num = node->size();
|
||||
if (node_inputs_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
|
||||
}
|
||||
|
||||
// node's inputs at indexes change to new_input_at_indexes
|
||||
if (!new_inputs->empty()) {
|
||||
new_inputs->resize(0);
|
||||
}
|
||||
new_inputs->push_back(node->input(0));
|
||||
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
|
||||
auto node_inputs_num = node->size();
|
||||
size_t idx = 0;
|
||||
for (size_t i = 0; i < node_inputs_num; ++i) {
|
||||
if (indexes_set.find(i) == indexes_set.end()) {
|
||||
for (size_t i = 1; i < node_inputs_num; ++i) {
|
||||
size_t data_idx = i - 1;
|
||||
if (indexes_set.find(data_idx) == indexes_set.end()) {
|
||||
new_inputs->push_back(node->input(i));
|
||||
} else {
|
||||
new_inputs->push_back(new_input_at_indexes[idx++]);
|
||||
|
@ -173,13 +165,57 @@ void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::v
|
|||
}
|
||||
}
|
||||
|
||||
void ReorderOps::SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
|
||||
const std::vector<AnfNodePtr> &input_at_indexes,
|
||||
NodeIOInfo *new_inputs_info, bool from_input) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(new_inputs_info);
|
||||
if (indexes.size() != input_at_indexes.size()) {
|
||||
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
|
||||
<< input_at_indexes.size();
|
||||
}
|
||||
|
||||
auto node_inputs_num = node->size();
|
||||
if (node_inputs_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
|
||||
}
|
||||
|
||||
// node's inputs info at indexes change to input_at_indexes's input or output info
|
||||
new_inputs_info->inputs_format.resize(0);
|
||||
new_inputs_info->inputs_type.resize(0);
|
||||
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
|
||||
size_t idx = 0;
|
||||
for (size_t data_idx = 0; data_idx < node_inputs_num - 1; ++data_idx) {
|
||||
if (indexes_set.find(data_idx) == indexes_set.end()) {
|
||||
new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(node, data_idx));
|
||||
new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(node, data_idx));
|
||||
} else {
|
||||
if (from_input) {
|
||||
new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(input_at_indexes[idx], 0));
|
||||
new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(input_at_indexes[idx], 0));
|
||||
} else {
|
||||
new_inputs_info->inputs_format.push_back(AnfAlgo::GetOutputFormat(input_at_indexes[idx], 0));
|
||||
new_inputs_info->inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(input_at_indexes[idx], 0));
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
||||
const CNodePtr &node) {
|
||||
// Limitation: Current cast node is CAST_DOWN.
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN) {
|
||||
// Limitation:
|
||||
// Current cast node is CAST_DOWN.
|
||||
// Cast node will not change the input format.
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN ||
|
||||
AnfAlgo::GetInputFormat(node, 0) != AnfAlgo::GetOutputFormat(node, 0)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto large_type = AnfAlgo::GetInputDeviceDataType(node, 0);
|
||||
auto small_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
||||
auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);
|
||||
|
||||
auto node_input = AnfAlgo::GetInputNode(node, 0);
|
||||
auto type_insens_node = node_input->cast<CNodePtr>();
|
||||
// Limitation:
|
||||
|
@ -190,11 +226,9 @@ bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph,
|
|||
return false;
|
||||
}
|
||||
|
||||
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
|
||||
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
||||
auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
|
||||
// Limitation: Type insensitive node's inputs have same data type.
|
||||
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, cast_input_type)) {
|
||||
// Limitation: Type insensitive node's inputs are the large type.
|
||||
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, large_type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -202,17 +236,23 @@ bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph,
|
|||
for (const auto &index : op_input_indexes) {
|
||||
auto new_cast_node =
|
||||
func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
|
||||
SetNodeInfo(node, new_cast_node, cast_out_type);
|
||||
NodeIOInfo cast_io_info;
|
||||
cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index));
|
||||
cast_io_info.outputs_format = cast_io_info.inputs_format;
|
||||
cast_io_info.inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(type_insens_node, index));
|
||||
cast_io_info.outputs_type.push_back(small_type);
|
||||
SetNodeInfo(node, new_cast_node, cast_io_info);
|
||||
new_cast_nodes.push_back(new_cast_node);
|
||||
}
|
||||
|
||||
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
|
||||
[](const size_t &idx) { return idx + 1; });
|
||||
|
||||
std::vector<AnfNodePtr> type_insens_node_new_inputs;
|
||||
SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
|
||||
NodeIOInfo type_insens_io_info;
|
||||
type_insens_io_info.outputs_format.push_back(pattern_output_format);
|
||||
type_insens_io_info.outputs_type.push_back(small_type);
|
||||
SetTypeInsensitiveNodeInputsInfo(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_io_info, false);
|
||||
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
|
||||
SetNodeInfo(type_insens_node, new_type_insens_node, cast_out_type);
|
||||
SetNodeInfo(type_insens_node, new_type_insens_node, type_insens_io_info);
|
||||
|
||||
(void)mng->Replace(node, new_type_insens_node);
|
||||
return true;
|
||||
|
@ -227,14 +267,16 @@ bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, co
|
|||
// Limitation:
|
||||
// Certain inputs of type insensitive node are cast node.
|
||||
// Cast nodes are CAST_UP.
|
||||
// Cast nodes will not change the input format.
|
||||
// All these cast nodes are only used by current type insensitive node.
|
||||
std::vector<CNodePtr> cast_nodes;
|
||||
std::vector<AnfNodePtr> cast_nodes;
|
||||
std::vector<AnfNodePtr> cast_input_nodes;
|
||||
auto op_input_indexes = GetOpDataInputIndexes(node);
|
||||
for (const auto &index : op_input_indexes) {
|
||||
auto node_input = AnfAlgo::GetInputNode(node, index);
|
||||
auto cast_node = node_input->cast<CNodePtr>();
|
||||
if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
|
||||
AnfAlgo::GetInputFormat(node, 0) == AnfAlgo::GetOutputFormat(node, 0) &&
|
||||
mng->node_users()[cast_node].size() == 1) {
|
||||
cast_nodes.push_back(cast_node);
|
||||
cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
|
||||
|
@ -244,29 +286,37 @@ bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, co
|
|||
return false;
|
||||
}
|
||||
|
||||
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
|
||||
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
|
||||
auto small_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
|
||||
auto large_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
|
||||
auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);
|
||||
|
||||
// Limitation: All these cast nodes cast same type to another type.
|
||||
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&cast_input_type](const CNodePtr &cast_node) {
|
||||
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == cast_input_type;
|
||||
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&small_type](const AnfNodePtr &cast_node) {
|
||||
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == small_type;
|
||||
})) {
|
||||
return false;
|
||||
}
|
||||
// Limitation: Type insensitive node's inputs have same data type.
|
||||
if (!CheckInputTypeConsistent(node, op_input_indexes, cast_out_type)) {
|
||||
if (!CheckInputTypeConsistent(node, op_input_indexes, large_type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
|
||||
[](const size_t &idx) { return idx + 1; });
|
||||
|
||||
std::vector<AnfNodePtr> type_insens_node_new_inputs;
|
||||
SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
|
||||
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
|
||||
SetNodeInfo(node, new_type_insens_node, cast_input_type);
|
||||
NodeIOInfo type_insens_io_info;
|
||||
type_insens_io_info.outputs_format.push_back(pattern_output_format);
|
||||
type_insens_io_info.outputs_type.push_back(small_type);
|
||||
SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true);
|
||||
SetNodeInfo(node, new_type_insens_node, type_insens_io_info);
|
||||
|
||||
auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
|
||||
SetNodeInfo(cast_nodes[0], new_cast_node, cast_out_type);
|
||||
NodeIOInfo cast_io_info;
|
||||
cast_io_info.inputs_format.push_back(pattern_output_format);
|
||||
cast_io_info.outputs_format = cast_io_info.inputs_format;
|
||||
cast_io_info.inputs_type.push_back(small_type);
|
||||
cast_io_info.outputs_type.push_back(large_type);
|
||||
SetNodeInfo(cast_nodes[0]->cast<CNodePtr>(), new_cast_node, cast_io_info);
|
||||
|
||||
(void)mng->Replace(node, new_cast_node);
|
||||
return true;
|
||||
|
|
|
@ -19,10 +19,18 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
struct NodeIOInfo {
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
std::vector<TypeId> outputs_type;
|
||||
};
|
||||
|
||||
class ReorderOps : public Pass {
|
||||
public:
|
||||
ReorderOps() : Pass("reorder_ops") {}
|
||||
|
@ -33,6 +41,9 @@ class ReorderOps : public Pass {
|
|||
void SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
|
||||
const std::vector<AnfNodePtr> &new_input_in_indexes,
|
||||
std::vector<AnfNodePtr> *new_inputs);
|
||||
void SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
|
||||
const std::vector<AnfNodePtr> &input_at_indexes, NodeIOInfo *new_inputs_info,
|
||||
bool from_input);
|
||||
bool ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
||||
const CNodePtr &node);
|
||||
bool ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
||||
|
|
Loading…
Reference in New Issue