diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index d60fc56d512..528cde2d538 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -1137,6 +1137,7 @@ constexpr auto kAttrPreKernelGraph = "pre_kernel_graph"; constexpr auto kAttrNeedInline = "need_inline"; constexpr auto kAttrOriFusionName = "ori_fusion_name"; constexpr auto kAttrDynamicLenName = "is_dynamic_len"; +constexpr auto kAttrForFormatChange = "for_format_change"; /* Node attribute name. CreateNewTranspose sets it to true on the Transpose that replaces a TransData, and TransposeInfer reads it and then returns the input shape unchanged. */ // FuncGraph Flags constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph"; diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc index cb6d2659caf..aab89d19134 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc @@ -187,6 +187,7 @@ #include "plugin/device/ascend/optimizer/optimizer_factory.h" #include "plugin/device/ascend/hal/common/ascend_utils.h" #include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h" +#include "plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.h" namespace mindspore { namespace opt { @@ -313,6 +314,7 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph) data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); if (kernel_graph->has_flag(kFlagPyNativeRunInGraph)) { data_layout_pm->AddPass(std::make_shared()); } diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.cc new file mode 100644 index 00000000000..50f1ce4b192 --- 
/dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.h" +#include +#include +#include +#include +#include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/helper.h" +#include "plugin/device/ascend/optimizer/ascend_helper.h" + +namespace mindspore { +namespace opt { +namespace { /* Position of the data operand in the TransData CNode input list; CreateNewTranspose forwards cnode->input(kTransDataInputIndex) to the new Transpose. */ +const size_t kTransDataInputIndex = 1; + /* Returns true when the TransData node is kept as is: the node is not dynamic-shape, or at least one of its input/output formats (index 0) is outside format_list, or at least one of its input/output device shapes does not have kDim4 dimensions. Returns false, meaning replace with Transpose, only for a dynamic-shape node whose two formats are both in format_list and whose two device shapes both have kDim4 dimensions. node is checked with MS_EXCEPTION_IF_NULL. */ +bool CheckTransDataSupport(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!common::AnfAlgo::IsDynamicShape(node)) { + return true; + } /* Formats for which the conversion is handed to Transpose. NOTE(review): this function-local static is never modified here and could be const. */ + static std::set format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_CHWN, kOpFormat_HWCN, + kOpFormat_NHWC}; + auto input_format = AnfAlgo::GetInputFormat(node, 0); + auto output_format = AnfAlgo::GetOutputFormat(node, 0); + auto input_shape = AnfAlgo::GetInputDeviceShape(node, 0); + auto output_shape = AnfAlgo::GetOutputDeviceShape(node, 0); + return format_list.find(input_format) == format_list.end() || format_list.find(output_format) == format_list.end() || + input_shape.size() != kDim4 || output_shape.size() != kDim4; +} + /* Computes the permutation that maps the input format of node to its output format. Both formats are read at index 0 and DEFAULT is substituted by NCHW. Entry i of the result is the position inside the input format string of the i-th character of the output format string. Logs at EXCEPTION level when an output character does not occur in the input format. */ +std::vector GetTransposePerm(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto input_format = AnfAlgo::GetInputFormat(node, 0); + auto output_format = 
AnfAlgo::GetOutputFormat(node, 0); /* DEFAULT is handled as NCHW on both sides. */ + if (input_format == kOpFormat_DEFAULT) { + input_format = kOpFormat_NCHW; + } + if (output_format == kOpFormat_DEFAULT) { + output_format = kOpFormat_NCHW; + } + std::vector perm_value; /* One entry per character of the output format string, in output order. */ + for (size_t i = 0; i < output_format.size(); i++) { + auto index = input_format.find((output_format[i])); + if (index == std::string::npos) { + MS_LOG(EXCEPTION) << "Can not find output dim [" << output_format[i] << "] in input format [" << input_format + << "]."; + } + perm_value.emplace_back(index); + } + return perm_value; +} + /* Builds the perm operand for the Transpose. The value is created from (GetTransposePerm(node), kInt64) and wrapped in a value node, which receives the abstract of that value, a fresh kernel info, and build info with output format DEFAULT, device type kNumberTypeInt64 and object type TENSOR. The value node is then registered on the kernel graph with AddValueNodeToGraph and returned. node is null-checked on entry. */ +ValueNodePtr CreatePermValueNode(const FuncGraphPtr &graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto perm = GetTransposePerm(node); + auto perm_value = std::make_shared(perm, kInt64); + auto perm_node = NewValueNode(perm_value); + MS_EXCEPTION_IF_NULL(perm_node); + auto value_abstract = perm_value->ToAbstract(); + perm_node->set_abstract(value_abstract); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + perm_node->set_kernel_info(kernel_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsFormat({kOpFormat_DEFAULT}); + builder.SetOutputsDeviceType({kNumberTypeInt64}); + builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR}); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), perm_node.get()); /* NOTE(review): graph is dereferenced by the cast below before any null check; only the cast result is checked. */ + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + kernel_graph->AddValueNodeToGraph(perm_node); + return perm_node; +} + /* Selects the kernel build info of the new Transpose from the TransData it replaces. Input 0 and output 0 take the format and device data type of the TransData input 0 and output 0. The second input (perm) is DEFAULT format with kNumberTypeInt64. Both inputs and the output are TENSOR objects with empty reshape types; processor is AICORE and kernel type is TBE_KERNEL. */ +void SetKernelBuildInfo(const AnfNodePtr &transdata, const AnfNodePtr &transpose) { + MS_EXCEPTION_IF_NULL(transpose); + MS_EXCEPTION_IF_NULL(transdata); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat({AnfAlgo::GetInputFormat(transdata, 0), kOpFormat_DEFAULT}); + selected_kernel_builder.SetInputsDeviceType( + {AnfAlgo::GetInputDeviceDataType(transdata, 0), TypeId::kNumberTypeInt64}); + 
selected_kernel_builder.SetOutputsFormat({AnfAlgo::GetOutputFormat(transdata, 0)}); + selected_kernel_builder.SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(transdata, 0)}); + selected_kernel_builder.SetInputsKernelObjectType( + {kernel::KernelObjectType::TENSOR, kernel::KernelObjectType::TENSOR}); + selected_kernel_builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR}); + selected_kernel_builder.SetInputsReshapeType({"", ""}); + selected_kernel_builder.SetOutputsReshapeType({""}); + selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); + selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), transpose.get()); +} + /* Creates the Transpose CNode that stands in for the TransData node. Its inputs are a value node built from kTransposeOpName, the TransData data operand and the perm value node. The new node gets build info from SetKernelBuildInfo, the input names x and perm, the output name output, kAttrForFormatChange set to true, and the abstract and scope of node. pass is used only to create the CNode through NewCNode. */ +AnfNodePtr CreateNewTranspose(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const PatternProcessPass &pass) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto perm_node = CreatePermValueNode(func_graph, node); + std::vector inputs = {NewValueNode(std::make_shared(kTransposeOpName)), + cnode->input(kTransDataInputIndex), perm_node}; + auto transpose = pass.NewCNode(inputs, func_graph); + MS_EXCEPTION_IF_NULL(transpose); + SetKernelBuildInfo(node, transpose); + std::vector input_names = {"x", "perm"}; + std::vector output_names = {"output"}; + common::AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose); + common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose); /* TransposeInfer reads the attribute set next and, when it is true, returns the input shape as the output shape. */ + common::AnfAlgo::SetNodeAttr(kAttrForFormatChange, MakeValue(true), transpose); + transpose->set_abstract(node->abstract()); + transpose->set_scope(node->scope()); + return transpose; +} +} // namespace + /* Pattern matched by this pass: kPrimTransData applied to a single operand bound to the variable x1. */ +const BaseRef ReplaceTransDataWithTranspose::DefinePattern() const { + VarPtr x1 = std::make_shared(); + return VectorRef({prim::kPrimTransData, x1}); +} + /* Handles one matched TransData node. Returns nullptr when CheckTransDataSupport(node) is true; otherwise returns the Transpose built by CreateNewTranspose. The EquivPtr argument is unused. */ +const AnfNodePtr ReplaceTransDataWithTranspose::Process(const FuncGraphPtr &func_graph, 
const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope(); + if (CheckTransDataSupport(node)) { + MS_LOG(DEBUG) << "TransData is support, no need replace. node: " << node->fullname_with_scope(); + return nullptr; + } /* Reached only when CheckTransDataSupport returned false. */ + auto transpose = CreateNewTranspose(func_graph, node, *this); + MS_EXCEPTION_IF_NULL(transpose); + MS_LOG(DEBUG) << "TransData is not supported from input format " << AnfAlgo::GetInputFormat(node, 0) + << " to output format " << AnfAlgo::GetOutputFormat(node, 0) + << " in dynamic shape scenario, replace TransData with Transpose." + << " Origin TransData: " << node->fullname_with_scope() + << ", New Transpose: " << transpose->fullname_with_scope(); + return transpose; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.h new file mode 100644 index 00000000000..2f680ad3f4a --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/replace_transdata_with_transpose.h @@ -0,0 +1,38 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REPLACE_TRANSDATA_WITH_TRANSPOSE_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REPLACE_TRANSDATA_WITH_TRANSPOSE_H + +#include +#include +#include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/helper.h" +#include "plugin/device/ascend/optimizer/ascend_helper.h" + +namespace mindspore { +namespace opt { /* Pattern pass registered under the name replace_transdata_with_transpose. The constructor forwards that name and the multigraph flag (default true) to PatternProcessPass. DefinePattern and Process are only declared here. */ +class ReplaceTransDataWithTranspose : public PatternProcessPass { + public: + explicit ReplaceTransDataWithTranspose(bool multigraph = true) + : PatternProcessPass("replace_transdata_with_transpose", multigraph) {} + ~ReplaceTransDataWithTranspose() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REPLACE_TRANSDATA_WITH_TRANSPOSE_H diff --git a/mindspore/core/ops/transpose.cc b/mindspore/core/ops/transpose.cc index 43b9513217d..4d1ff9d0b66 100644 --- a/mindspore/core/ops/transpose.cc +++ b/mindspore/core/ops/transpose.cc @@ -82,6 +82,11 @@ class TransposeInfer : public abstract::OpInferBase { CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, op_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; (void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name); + + auto for_format_change_value = primitive->GetAttr(kAttrForFormatChange); + if (for_format_change_value != nullptr && GetValue(for_format_change_value)) { + return std::make_shared(x_shape); + } ShapeVector p_value; if (x_shape[0] == 0) { MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";