enhancement reorder_ops pass to support reordering cast and type insensitive operators
support castup, type-insensitive to type-insensitive, castup refactor reorder_ops fix compiling move reorder_ops pass to later fix abstract refactor fix node input num bug
This commit is contained in:
parent
6801ef61e0
commit
e88cdc84ec
|
@ -53,9 +53,6 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
|
||||||
if (is_ascend) {
|
if (is_ascend) {
|
||||||
// Remove redundant Cast(bias, fp16) for Matmul input
|
// Remove redundant Cast(bias, fp16) for Matmul input
|
||||||
pm->AddPass(std::make_shared<CastMatmulFusion>());
|
pm->AddPass(std::make_shared<CastMatmulFusion>());
|
||||||
|
|
||||||
// Reorder TransData-Cast to Cast-TransData
|
|
||||||
pm->AddPass(std::make_shared<ReorderOps>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Spread the MakeTuple input of UpdateState
|
// Spread the MakeTuple input of UpdateState
|
||||||
|
@ -78,6 +75,9 @@ PassManagerPtr GraphKernelOptimizer::Cluster() {
|
||||||
|
|
||||||
PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() {
|
PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() {
|
||||||
auto pm = std::make_shared<PassManager>("graphkernel_stage3_highlevelopt1");
|
auto pm = std::make_shared<PassManager>("graphkernel_stage3_highlevelopt1");
|
||||||
|
// Reorder Cast and Type-insensitive node
|
||||||
|
pm->AddPass(std::make_shared<ReorderOps>());
|
||||||
|
|
||||||
// normalize the Reduce axis
|
// normalize the Reduce axis
|
||||||
pm->AddPass(std::make_shared<AxisNormalizer>());
|
pm->AddPass(std::make_shared<AxisNormalizer>());
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -18,6 +18,8 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_set>
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
@ -27,78 +29,111 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
bool CanReorder(const FuncGraphManagerPtr &mng, const CNodePtr &transdata_node, const CNodePtr &cast_node) {
|
bool IsTypeInsensitive(const CNodePtr &node) {
|
||||||
auto transdata_input_type = AnfAlgo::GetInputDeviceDataType(transdata_node, 0);
|
// Nodes that will change the input data type will not seen as type insensitive nodes.
|
||||||
auto transdata_output_type = AnfAlgo::GetOutputDeviceDataType(transdata_node, 0);
|
static std::unordered_set<PrimitivePtr> type_insensitive_op_list{
|
||||||
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
|
prim::KPrimTransData, prim::kPrimTranspose, prim::kPrimExpandDims, prim::kPrimReshape,
|
||||||
auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
|
prim::kPrimSqueeze, prim::kPrimTile, prim::kPrimNeg, prim::kPrimRelu,
|
||||||
// Conditions of reordering transdata_cast to cast_transdata:
|
prim::kPrimMaximum, prim::kPrimMinimum, prim::kPrimSelect};
|
||||||
// 1) current transdata is only used by cast
|
|
||||||
// 2) transdata works on float32 (transdata supports float16/float32;
|
return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
|
||||||
// transdata performances better on float16 due to less data to process)
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||||
// 3) cast works on float32 -> float16
|
|
||||||
if (mng->node_users()[transdata_node].size() == 1 && transdata_input_type == kNumberTypeFloat32 &&
|
|
||||||
transdata_output_type == transdata_input_type && cast_input_type == transdata_output_type &&
|
|
||||||
cast_output_type == kNumberTypeFloat16) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, const CNodePtr &node) {
|
enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
|
||||||
// Initial
|
CastType GetCastType(const CNodePtr &node) {
|
||||||
// TransData: (type0, format0) -> (type0, format1)
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
// Cast: (type0, format1) -> (type1, format1)
|
if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
|
||||||
// After reorder
|
MS_LOG(EXCEPTION) << "Only process for Cast!";
|
||||||
// Cast: (type0, format0) -> (type1, format0)
|
}
|
||||||
// TransData: (type1, format0) -> (type1, format1)
|
TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
|
||||||
auto type0 = AnfAlgo::GetInputDeviceDataType(transdata_node, 0);
|
TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
||||||
auto type1 = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
|
|
||||||
auto format0 = AnfAlgo::GetInputFormat(transdata_node, 0);
|
|
||||||
auto format1 = AnfAlgo::GetOutputFormat(transdata_node, 0);
|
|
||||||
|
|
||||||
auto abstract = transdata_node->abstract();
|
if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
|
||||||
auto scope = cast_node->scope();
|
return CAST_UP;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
|
||||||
|
return CAST_DOWN;
|
||||||
|
}
|
||||||
|
|
||||||
|
return CAST_OTHER;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
|
||||||
|
std::vector<size_t> op_input_indexes;
|
||||||
|
if (node == nullptr || !IsTypeInsensitive(node)) {
|
||||||
|
return op_input_indexes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data input index starts from 0.
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
|
||||||
|
op_input_indexes = {0, 1};
|
||||||
|
} else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
|
||||||
|
op_input_indexes = {1, 2};
|
||||||
|
} else {
|
||||||
|
op_input_indexes = {0};
|
||||||
|
}
|
||||||
|
return op_input_indexes;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes, const TypeId &base_type) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
|
||||||
|
// node's inputs at check_indexes should be of type base_type
|
||||||
|
for (const auto &index : check_indexes) {
|
||||||
|
if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const TypeId &node_type) {
|
||||||
|
MS_EXCEPTION_IF_NULL(orig_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
|
|
||||||
|
auto node_name = AnfAlgo::GetCNodeName(new_node);
|
||||||
|
auto orig_node_name = AnfAlgo::GetCNodeName(orig_node);
|
||||||
|
if (orig_node_name != node_name) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not process on different nodes " << orig_node_name << " and " << node_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr new_abstract{nullptr};
|
||||||
std::vector<std::string> inputs_format;
|
std::vector<std::string> inputs_format;
|
||||||
std::vector<std::string> outputs_format;
|
std::vector<std::string> outputs_format;
|
||||||
std::vector<TypeId> inputs_device_type;
|
std::vector<TypeId> inputs_device_type;
|
||||||
std::vector<TypeId> outputs_device_type;
|
std::vector<TypeId> outputs_device_type{node_type};
|
||||||
auto kernel_type = AnfAlgo::GetKernelType(cast_node);
|
KernelType kernel_type{AnfAlgo::GetKernelType(orig_node)};
|
||||||
auto op_pattern = AnfAlgo::GetOpPattern(cast_node);
|
kernel::OpPattern op_pattern{AnfAlgo::GetOpPattern(orig_node)};
|
||||||
auto fusion_type = AnfAlgo::GetFusionType(cast_node);
|
kernel::FusionType fusion_type{AnfAlgo::GetFusionType(orig_node)};
|
||||||
auto processor = AnfAlgo::GetProcessor(cast_node);
|
kernel::Processor processor{AnfAlgo::GetProcessor(orig_node)};
|
||||||
|
|
||||||
auto node_name = AnfAlgo::GetCNodeName(node);
|
auto node_data_inputs_num = AnfAlgo::GetInputNum(new_node);
|
||||||
|
for (size_t i = 0; i < node_data_inputs_num; ++i) {
|
||||||
|
auto node_input = AnfAlgo::GetInputNode(new_node, i);
|
||||||
|
auto node_input_format = AnfAlgo::GetOutputFormat(node_input, 0);
|
||||||
|
auto node_input_type = AnfAlgo::GetOutputDeviceDataType(node_input, 0);
|
||||||
|
inputs_format.push_back(node_input_format);
|
||||||
|
inputs_device_type.push_back(node_input_type);
|
||||||
|
}
|
||||||
if (node_name == "Cast") {
|
if (node_name == "Cast") {
|
||||||
inputs_format.push_back(format0);
|
auto node_input = AnfAlgo::GetInputNode(new_node, 0);
|
||||||
outputs_format.push_back(format0);
|
new_abstract =
|
||||||
inputs_device_type.push_back(type0);
|
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), node_input->abstract()->BuildShape());
|
||||||
outputs_device_type.push_back(type1);
|
outputs_format.push_back(AnfAlgo::GetOutputFormat(node_input, 0));
|
||||||
// Set attrs
|
|
||||||
AnfAlgo::CopyNodeAttrs(cast_node, node);
|
|
||||||
} else if (node_name == "TransData") {
|
|
||||||
abstract = cast_node->abstract();
|
|
||||||
scope = transdata_node->scope();
|
|
||||||
inputs_format.push_back(format0);
|
|
||||||
outputs_format.push_back(format1);
|
|
||||||
inputs_device_type.push_back(type1);
|
|
||||||
outputs_device_type.push_back(type1);
|
|
||||||
kernel_type = AnfAlgo::GetKernelType(transdata_node);
|
|
||||||
op_pattern = AnfAlgo::GetOpPattern(transdata_node);
|
|
||||||
fusion_type = AnfAlgo::GetFusionType(transdata_node);
|
|
||||||
processor = AnfAlgo::GetProcessor(transdata_node);
|
|
||||||
// Set attrs
|
|
||||||
AnfAlgo::CopyNodeAttrs(transdata_node, node);
|
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Node must be Cast or TransData";
|
new_abstract =
|
||||||
|
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), orig_node->abstract()->BuildShape());
|
||||||
|
outputs_format.push_back(AnfAlgo::GetOutputFormat(orig_node, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set abstract info
|
// Set abstract info
|
||||||
node->set_abstract(abstract);
|
new_node->set_abstract(new_abstract);
|
||||||
// Set scope info
|
// Set attrs
|
||||||
node->set_scope(scope);
|
AnfAlgo::CopyNodeAttrs(orig_node, new_node);
|
||||||
// Set kernel build info
|
// Set kernel build info
|
||||||
node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
|
||||||
info_builder.SetInputsFormat(inputs_format);
|
info_builder.SetInputsFormat(inputs_format);
|
||||||
info_builder.SetInputsDeviceType(inputs_device_type);
|
info_builder.SetInputsDeviceType(inputs_device_type);
|
||||||
|
@ -108,10 +143,141 @@ void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, cons
|
||||||
info_builder.SetOpPattern(op_pattern);
|
info_builder.SetOpPattern(op_pattern);
|
||||||
info_builder.SetFusionType(fusion_type);
|
info_builder.SetFusionType(fusion_type);
|
||||||
info_builder.SetProcessor(processor);
|
info_builder.SetProcessor(processor);
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), node.get());
|
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
|
||||||
|
const std::vector<AnfNodePtr> &new_input_at_indexes,
|
||||||
|
std::vector<AnfNodePtr> *new_inputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_inputs);
|
||||||
|
if (indexes.size() != new_input_at_indexes.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
|
||||||
|
<< new_input_at_indexes.size();
|
||||||
|
}
|
||||||
|
if (!new_inputs->empty()) {
|
||||||
|
new_inputs->resize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// node's inputs at indexes change to new_input_at_indexes
|
||||||
|
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
|
||||||
|
auto node_inputs_num = node->size();
|
||||||
|
size_t idx = 0;
|
||||||
|
for (size_t i = 0; i < node_inputs_num; ++i) {
|
||||||
|
if (indexes_set.find(i) == indexes_set.end()) {
|
||||||
|
new_inputs->push_back(node->input(i));
|
||||||
|
} else {
|
||||||
|
new_inputs->push_back(new_input_at_indexes[idx++]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ReorderTransDataCast(const FuncGraphPtr &func_graph) {
|
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
||||||
|
const CNodePtr &node) {
|
||||||
|
// Limitation: Current cast node is CAST_DOWN.
|
||||||
|
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto node_input = AnfAlgo::GetInputNode(node, 0);
|
||||||
|
auto type_insens_node = node_input->cast<CNodePtr>();
|
||||||
|
// Limitation:
|
||||||
|
// Find type insensitive node before cast node.
|
||||||
|
// Type insensitive node is only used by current cast node.
|
||||||
|
if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
|
||||||
|
mng->node_users()[type_insens_node].size() > 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
|
||||||
|
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
||||||
|
auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
|
||||||
|
// Limitation: Type insensitive node's inputs have same data type.
|
||||||
|
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, cast_input_type)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> new_cast_nodes;
|
||||||
|
for (const auto &index : op_input_indexes) {
|
||||||
|
auto new_cast_node =
|
||||||
|
func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
|
||||||
|
SetNodeInfo(node, new_cast_node, cast_out_type);
|
||||||
|
new_cast_nodes.push_back(new_cast_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
|
||||||
|
[](const size_t &idx) { return idx + 1; });
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> type_insens_node_new_inputs;
|
||||||
|
SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
|
||||||
|
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
|
||||||
|
SetNodeInfo(type_insens_node, new_type_insens_node, cast_out_type);
|
||||||
|
|
||||||
|
(void)mng->Replace(node, new_type_insens_node);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
||||||
|
const CNodePtr &node) {
|
||||||
|
if (!IsTypeInsensitive(node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limitation:
|
||||||
|
// Certain inputs of type insensitive node are cast node.
|
||||||
|
// Cast nodes are CAST_UP.
|
||||||
|
// All these cast nodes are only used by current type insensitive node.
|
||||||
|
std::vector<CNodePtr> cast_nodes;
|
||||||
|
std::vector<AnfNodePtr> cast_input_nodes;
|
||||||
|
auto op_input_indexes = GetOpDataInputIndexes(node);
|
||||||
|
for (const auto &index : op_input_indexes) {
|
||||||
|
auto node_input = AnfAlgo::GetInputNode(node, index);
|
||||||
|
auto cast_node = node_input->cast<CNodePtr>();
|
||||||
|
if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
|
||||||
|
mng->node_users()[cast_node].size() == 1) {
|
||||||
|
cast_nodes.push_back(cast_node);
|
||||||
|
cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
|
||||||
|
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
|
||||||
|
// Limitation: All these cast nodes cast same type to another type.
|
||||||
|
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&cast_input_type](const CNodePtr &cast_node) {
|
||||||
|
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == cast_input_type;
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Limitation: Type insensitive node's inputs have same data type.
|
||||||
|
if (!CheckInputTypeConsistent(node, op_input_indexes, cast_out_type)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
|
||||||
|
[](const size_t &idx) { return idx + 1; });
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> type_insens_node_new_inputs;
|
||||||
|
SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
|
||||||
|
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
|
||||||
|
SetNodeInfo(node, new_type_insens_node, cast_input_type);
|
||||||
|
|
||||||
|
auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
|
||||||
|
SetNodeInfo(cast_nodes[0], new_cast_node, cast_out_type);
|
||||||
|
|
||||||
|
(void)mng->Replace(node, new_cast_node);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) {
|
||||||
|
// Reorder cast node and type insensitive node in graph kernel sub-graph, this function has several limitations,
|
||||||
|
// see the comments that start will "Limitation:" in this file.
|
||||||
|
// Limitation: Assuming the type insensitive node will not change the type of input nodes, otherwise it can be seen
|
||||||
|
// as another cast node in some sense, such as LessEqual operator, which performs on two inputs and output a
|
||||||
|
// a boolean result.
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
auto mng = func_graph->manager();
|
auto mng = func_graph->manager();
|
||||||
if (mng == nullptr) {
|
if (mng == nullptr) {
|
||||||
|
@ -121,40 +287,52 @@ bool ReorderTransDataCast(const FuncGraphPtr &func_graph) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
auto todos = TopoSort(func_graph->get_return());
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
for (const auto &anf_node : todos) {
|
for (const auto &anf_node : todos) {
|
||||||
// Find cast node.
|
auto node = anf_node->cast<CNodePtr>();
|
||||||
auto cast_node = anf_node->cast<CNodePtr>();
|
if (node == nullptr) {
|
||||||
if (cast_node == nullptr || !AnfAlgo::CheckPrimitiveType(cast_node, prim::kPrimCast)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find transdata node before cast node.
|
if (IsTypeInsensitive(node)) {
|
||||||
auto cast_input = AnfAlgo::GetInputNode(cast_node, 0);
|
// Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
|
||||||
auto transdata_node = cast_input->cast<CNodePtr>();
|
changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
|
||||||
if (transdata_node == nullptr || !AnfAlgo::CheckPrimitiveType(transdata_node, prim::KPrimTransData)) {
|
} else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
|
||||||
continue;
|
// Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
|
||||||
|
changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reorder transdata_cast to cast_transdata if possible.
|
|
||||||
if (!CanReorder(mng, transdata_node, cast_node)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "Reorder " << transdata_node->fullname_with_scope() << ", " << cast_node->fullname_with_scope();
|
|
||||||
|
|
||||||
auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), transdata_node->inputs()[1]});
|
|
||||||
SetNodeInfo(transdata_node, cast_node, new_cast_node);
|
|
||||||
|
|
||||||
auto new_transdata_node = func_graph->NewCNode({NewValueNode(prim::KPrimTransData), new_cast_node});
|
|
||||||
SetNodeInfo(transdata_node, cast_node, new_transdata_node);
|
|
||||||
|
|
||||||
(void)mng->Replace(cast_node, new_transdata_node);
|
|
||||||
changed = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
bool ReorderOps::Run(const FuncGraphPtr &func_graph) { return ReorderTransDataCast(func_graph); }
|
bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(func_graph, true);
|
||||||
|
func_graph->set_manager(mng);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool changed = false;
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
for (const auto &anf_node : todos) {
|
||||||
|
auto node = anf_node->cast<CNodePtr>();
|
||||||
|
if (node == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (AnfAlgo::IsGraphKernel(node)) {
|
||||||
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||||
|
bool need_traverse = true;
|
||||||
|
while (need_traverse) {
|
||||||
|
need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
|
||||||
|
if (need_traverse) {
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REORDER_OPS_H_
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REORDER_OPS_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
#include "backend/optimizer/common/pass.h"
|
#include "backend/optimizer/common/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -27,6 +28,16 @@ class ReorderOps : public Pass {
|
||||||
ReorderOps() : Pass("reorder_ops") {}
|
ReorderOps() : Pass("reorder_ops") {}
|
||||||
~ReorderOps() override = default;
|
~ReorderOps() override = default;
|
||||||
bool Run(const FuncGraphPtr &func_graph) override;
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
|
||||||
|
const std::vector<AnfNodePtr> &new_input_in_indexes,
|
||||||
|
std::vector<AnfNodePtr> *new_inputs);
|
||||||
|
bool ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
||||||
|
const CNodePtr &node);
|
||||||
|
bool ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
||||||
|
const CNodePtr &node);
|
||||||
|
bool ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph);
|
||||||
};
|
};
|
||||||
using ReorderOpsPtr = std::shared_ptr<ReorderOps>;
|
using ReorderOpsPtr = std::shared_ptr<ReorderOps>;
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class CastUpNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CastUpNet, self).__init__()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.neg = P.Neg()
|
||||||
|
|
||||||
|
def construct(self, i0):
|
||||||
|
res = self.cast(i0, mstype.float32)
|
||||||
|
res = self.transpose(res, (1, 0))
|
||||||
|
res = self.neg(res)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def get_castup_output(x0, enable_graph_kernel=False):
|
||||||
|
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||||
|
net = CastUpNet()
|
||||||
|
output = net(x0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def test_castup():
|
||||||
|
x0 = Tensor(np.random.normal(0, 1, (16, 16)).astype(np.float16))
|
||||||
|
expect = get_castup_output(x0, False)
|
||||||
|
output = get_castup_output(x0, True)
|
||||||
|
expect_np = expect.asnumpy().copy()
|
||||||
|
output_np = output.asnumpy().copy()
|
||||||
|
assert np.allclose(expect_np, output_np, 1e-4, 1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
class CastDownNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CastDownNet, self).__init__()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.neg = P.Neg()
|
||||||
|
|
||||||
|
def construct(self, i0):
|
||||||
|
res = self.transpose(i0, (1, 0))
|
||||||
|
res = self.neg(res)
|
||||||
|
res = self.cast(res, mstype.float16)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def get_castdown_output(x0, enable_graph_kernel=False):
|
||||||
|
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||||
|
net = CastDownNet()
|
||||||
|
output = net(x0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def test_castdown():
|
||||||
|
x0 = Tensor(np.random.normal(0, 1, (16, 16)).astype(np.float32))
|
||||||
|
expect = get_castdown_output(x0, False)
|
||||||
|
output = get_castdown_output(x0, True)
|
||||||
|
expect_np = expect.asnumpy().copy()
|
||||||
|
output_np = output.asnumpy().copy()
|
||||||
|
assert np.allclose(expect_np, output_np, 1e-3, 1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_castup_gpu():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_castup()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_castup_ascend():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
test_castup()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_castdown_gpu():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_castdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_castdown_ascend():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
test_castdown()
|
Loading…
Reference in New Issue