!49412 Replace TransData With Transpose

Merge pull request !49412 from jiaorui/replace-transdata-by-transpose
This commit is contained in:
i-robot 2023-03-08 01:22:26 +00:00 committed by Gitee
commit 7a1dc8e325
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 201 additions and 0 deletions

View File

@ -1146,6 +1146,7 @@ constexpr auto kAttrPreKernelGraph = "pre_kernel_graph";
constexpr auto kAttrNeedInline = "need_inline";
constexpr auto kAttrOriFusionName = "ori_fusion_name";
constexpr auto kAttrDynamicLenName = "is_dynamic_len";
constexpr auto kAttrForFormatChange = "for_format_change";
// FuncGraph Flags
constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";

View File

@ -188,6 +188,7 @@
#include "plugin/device/ascend/optimizer/optimizer_factory.h"
#include "plugin/device/ascend/hal/common/ascend_utils.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#include "plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.h"
namespace mindspore {
namespace opt {
@ -314,6 +315,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>());
data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>());
data_layout_pm->AddPass(std::make_shared<ReplaceTransDataWithTranspose>());
if (kernel_graph->has_flag(kFlagPyNativeRunInGraph)) {
data_layout_pm->AddPass(std::make_shared<EliminateGraphOutputTransdata>());
}

View File

@ -0,0 +1,155 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.h"
#include <string>
#include <memory>
#include <set>
#include <vector>
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/helper.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
namespace mindspore {
namespace opt {
namespace {
const size_t kTransDataInputIndex = 1;
bool CheckTransDataSupport(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!common::AnfAlgo::IsDynamicShape(node)) {
return true;
}
static std::set<string> format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_CHWN, kOpFormat_HWCN,
kOpFormat_NHWC};
auto input_format = AnfAlgo::GetInputFormat(node, 0);
auto output_format = AnfAlgo::GetOutputFormat(node, 0);
auto input_shape = AnfAlgo::GetInputDeviceShape(node, 0);
auto output_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
return format_list.find(input_format) == format_list.end() || format_list.find(output_format) == format_list.end() ||
input_shape.size() != kDim4 || output_shape.size() != kDim4;
}
std::vector<int64_t> GetTransposePerm(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto input_format = AnfAlgo::GetInputFormat(node, 0);
auto output_format = AnfAlgo::GetOutputFormat(node, 0);
if (input_format == kOpFormat_DEFAULT) {
input_format = kOpFormat_NCHW;
}
if (output_format == kOpFormat_DEFAULT) {
output_format = kOpFormat_NCHW;
}
std::vector<int64_t> perm_value;
for (size_t i = 0; i < output_format.size(); i++) {
auto index = input_format.find((output_format[i]));
if (index == std::string::npos) {
MS_LOG(EXCEPTION) << "Can not find output dim [" << output_format[i] << "] in input format [" << input_format
<< "].";
}
perm_value.emplace_back(index);
}
return perm_value;
}
ValueNodePtr CreatePermValueNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto perm = GetTransposePerm(node);
auto perm_value = std::make_shared<tensor::Tensor>(perm, kInt64);
auto perm_node = NewValueNode(perm_value);
MS_EXCEPTION_IF_NULL(perm_node);
auto value_abstract = perm_value->ToAbstract();
perm_node->set_abstract(value_abstract);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
perm_node->set_kernel_info(kernel_info);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetOutputsDeviceType({kNumberTypeInt64});
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), perm_node.get());
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->AddValueNodeToGraph(perm_node);
return perm_node;
}
void SetKernelBuildInfo(const AnfNodePtr &transdata, const AnfNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(transpose);
MS_EXCEPTION_IF_NULL(transdata);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetInputsFormat({AnfAlgo::GetInputFormat(transdata, 0), kOpFormat_DEFAULT});
selected_kernel_builder.SetInputsDeviceType(
{AnfAlgo::GetInputDeviceDataType(transdata, 0), TypeId::kNumberTypeInt64});
selected_kernel_builder.SetOutputsFormat({AnfAlgo::GetOutputFormat(transdata, 0)});
selected_kernel_builder.SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(transdata, 0)});
selected_kernel_builder.SetInputsKernelObjectType(
{kernel::KernelObjectType::TENSOR, kernel::KernelObjectType::TENSOR});
selected_kernel_builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
selected_kernel_builder.SetInputsReshapeType({"", ""});
selected_kernel_builder.SetOutputsReshapeType({""});
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), transpose.get());
}
AnfNodePtr CreateNewTranspose(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto perm_node = CreatePermValueNode(func_graph, node);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)),
cnode->input(kTransDataInputIndex), perm_node};
auto transpose = pass.NewCNode(inputs, func_graph);
MS_EXCEPTION_IF_NULL(transpose);
SetKernelBuildInfo(node, transpose);
std::vector<std::string> input_names = {"x", "perm"};
std::vector<std::string> output_names = {"output"};
common::AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose);
common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose);
common::AnfAlgo::SetNodeAttr(kAttrForFormatChange, MakeValue(true), transpose);
transpose->set_abstract(node->abstract());
transpose->set_scope(node->scope());
return transpose;
}
} // namespace
const BaseRef ReplaceTransDataWithTranspose::DefinePattern() const {
VarPtr x1 = std::make_shared<Var>();
return VectorRef({prim::kPrimTransData, x1});
}
const AnfNodePtr ReplaceTransDataWithTranspose::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope();
if (CheckTransDataSupport(node)) {
MS_LOG(DEBUG) << "TransData is support, no need replace. node: " << node->fullname_with_scope();
return nullptr;
}
auto transpose = CreateNewTranspose(func_graph, node, *this);
MS_EXCEPTION_IF_NULL(transpose);
MS_LOG(DEBUG) << "TransData is not supported from input format " << AnfAlgo::GetInputFormat(node, 0)
<< " to output format " << AnfAlgo::GetOutputFormat(node, 0)
<< " in dynamic shape scenario, replace TransData with Transpose."
<< " Origin TransData: " << node->fullname_with_scope()
<< ", New Transpose: " << transpose->fullname_with_scope();
return transpose;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REPLACE_TRANSDATA_WITH_TRANSPOSE_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REPLACE_TRANSDATA_WITH_TRANSPOSE_H
#include <string>
#include <memory>
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/helper.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
namespace mindspore {
namespace opt {
class ReplaceTransDataWithTranspose : public PatternProcessPass {
public:
explicit ReplaceTransDataWithTranspose(bool multigraph = true)
: PatternProcessPass("replace_transdata_with_transpose", multigraph) {}
~ReplaceTransDataWithTranspose() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_REPLACE_TRANSDATA_WITH_TRANSPOSE_H

View File

@ -82,6 +82,11 @@ class TransposeInfer : public abstract::OpInferBase {
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name);
auto for_format_change_value = primitive->GetAttr(kAttrForFormatChange);
if (for_format_change_value != nullptr && GetValue<bool>(for_format_change_value)) {
return std::make_shared<abstract::Shape>(x_shape);
}
ShapeVector p_value;
if (x_shape[0] == 0) {
MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";