From cb05c9b8150a4d80b438551c007ddaf16d5b5aea Mon Sep 17 00:00:00 2001
From: yangruoqi713
Date: Fri, 19 Apr 2024 16:08:13 +0800
Subject: [PATCH] Add pass for dyn paddings in PadV3

---
 .../ccsrc/backend/common/optimizer/helper.cc  |  26 ++-
 .../optimizer/ge/convert_pad_v3_paddings.cc   | 206 ++++++++++++++++--
 .../optimizer/ge/convert_pad_v3_paddings.h    |  50 ++++-
 mindspore/core/ops/pad_v3.cc                  |  29 +++
 tests/st/ops/ascend/test_pad_v3.py            |   5 +-
 5 files changed, 272 insertions(+), 44 deletions(-)

diff --git a/mindspore/ccsrc/backend/common/optimizer/helper.cc b/mindspore/ccsrc/backend/common/optimizer/helper.cc
index 3f6480a9943..4b871f85ca0 100644
--- a/mindspore/ccsrc/backend/common/optimizer/helper.cc
+++ b/mindspore/ccsrc/backend/common/optimizer/helper.cc
@@ -1562,18 +1562,24 @@ AnfNodePtr CreateValueNodeWithKernelInfo(const FuncGraphPtr &graph, const ValueP
   value_node->set_kernel_info(kernel_info);
   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
   builder.SetOutputsFormat({kOpFormat_DEFAULT});
-  MS_EXCEPTION_IF_NULL(value->type());
-  auto type_id = value->type()->type_id();
-  if (value->isa<ValueSequence>()) {
-    auto value_sequence = value->cast<ValueSequencePtr>()->value();
-    if (value_sequence.empty()) {
-      type_id = kNumberTypeInt64;
-    } else {
-      MS_EXCEPTION_IF_NULL(value_sequence[0]->type());
-      type_id = value_sequence[0]->type()->type_id();
+  if (value->isa<tensor::Tensor>()) {
+    auto tensor = value->cast<tensor::TensorPtr>();
+    MS_EXCEPTION_IF_NULL(tensor);
+    builder.SetOutputsDeviceType({tensor->data_type()});
+  } else {
+    MS_EXCEPTION_IF_NULL(value->type());
+    auto type_id = value->type()->type_id();
+    if (value->isa<ValueSequence>()) {
+      auto value_sequence = value->cast<ValueSequencePtr>()->value();
+      if (value_sequence.empty()) {
+        type_id = kNumberTypeInt64;
+      } else {
+        MS_EXCEPTION_IF_NULL(value_sequence[0]->type());
+        type_id = value_sequence[0]->type()->type_id();
+      }
     }
+    builder.SetOutputsDeviceType({type_id});
   }
-  builder.SetOutputsDeviceType({type_id});
   auto object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(value_abs));
   builder.SetOutputsKernelObjectType({object_type});
   AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), value_node.get());
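A quick illustration of the helper change above: for tensor values the dtype is now taken from the tensor object itself (tensor->data_type()) instead of value->type(), and the old scalar/sequence logic becomes the fallback branch. A standalone sketch of the resulting dispatch order, using made-up stand-in types rather than MindSpore's real Value hierarchy:

```cpp
#include <cassert>
#include <vector>

// Stand-ins for MindSpore's Value / Tensor / ValueSequence (illustrative only).
enum TypeId { kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat32 };
struct Value { virtual ~Value() = default; };
struct TensorValue : Value { TypeId data_type; };
struct ScalarValue : Value { TypeId type; };
struct SequenceValue : Value { std::vector<ScalarValue> elements; };

// Mirrors the patched branch order: tensors report their own dtype, an empty
// sequence falls back to int64, a non-empty sequence uses its first element,
// and plain scalars use their own type.
TypeId DeriveOutputDeviceType(const Value &value) {
  if (auto tensor = dynamic_cast<const TensorValue *>(&value)) {
    return tensor->data_type;
  }
  if (auto seq = dynamic_cast<const SequenceValue *>(&value)) {
    return seq->elements.empty() ? kNumberTypeInt64 : seq->elements.front().type;
  }
  return dynamic_cast<const ScalarValue &>(value).type;
}

int main() {
  TensorValue tensor;
  tensor.data_type = kNumberTypeInt32;
  assert(DeriveOutputDeviceType(tensor) == kNumberTypeInt32);
  SequenceValue empty_seq;
  assert(DeriveOutputDeviceType(empty_seq) == kNumberTypeInt64);
  return 0;
}
```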
diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.cc
index 2240c1bd00d..ecaf1560036 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.cc
@@ -24,6 +24,8 @@
 #include "include/common/utils/anfalgo.h"
 #include "include/backend/anf_runtime_algorithm.h"
 #include "mindspore/core/ops/array_op_name.h"
+#include "mindspore/core/ops/sequence_op_name.h"
+#include "mindspore/core/ops/auto_generate/gen_ops_name.h"
 
 namespace mindspore {
 namespace opt {
@@ -37,9 +39,6 @@ bool ConvertBasePaddings::HasDynPaddings(const CNodePtr &cnode) const {
   MS_EXCEPTION_IF_NULL(paddings_abstract);
   auto paddings_value = paddings_abstract->GetValue();
   MS_EXCEPTION_IF_NULL(paddings_value);
-  if (paddings_value->isa<ValueAny>() || paddings_value->isa<None>()) {
-    return true;
-  }
   auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, kIndex1);
   if (input_paddings_type_id == kNumberTypeInt32) {
     auto paddings_array_value = ops::GetArrayValue<int32_t>(paddings_value);
@@ -64,6 +63,146 @@ const CNodePtr ConvertBasePaddings::CreateReshapeNode(const FuncGraphPtr &graph,
   return reshape_node;
 }
 
+const CNodePtr ConvertBasePaddings::CreateStridedSliceNode(const FuncGraphPtr &func_graph,
+                                                           const AnfNodePtr &input_node, int64_t index) const {
+  // set inputs
+  auto begin_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(std::vector<int64_t>{index}));
+  MS_EXCEPTION_IF_NULL(begin_node);
+  auto end_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(std::vector<int64_t>{index + 1}));
+  MS_EXCEPTION_IF_NULL(end_node);
+  auto strides_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(std::vector<int64_t>{1}));
+  MS_EXCEPTION_IF_NULL(strides_node);
+  int64_t const_value = 0;
+  auto begin_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
+  MS_EXCEPTION_IF_NULL(begin_mask);
+  auto end_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
+  MS_EXCEPTION_IF_NULL(end_mask);
+  auto ellipsis_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
+  MS_EXCEPTION_IF_NULL(ellipsis_mask);
+  auto new_axis_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
+  MS_EXCEPTION_IF_NULL(new_axis_mask);
+  auto shrink_axis_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
+  MS_EXCEPTION_IF_NULL(shrink_axis_mask);
+
+  auto prim = std::make_shared<Primitive>(kStridedSliceOpName);
+  MS_EXCEPTION_IF_NULL(prim);
+  AnfNodePtrList inputs = {NewValueNode(prim), input_node, begin_node, end_node, strides_node,
+                           begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask};
+  auto strided_slice_node = NewCNode(inputs, func_graph);
+  MS_EXCEPTION_IF_NULL(strided_slice_node);
+  auto abs = InferAbstract(prim, {input_node, begin_node, end_node, strides_node, begin_mask, end_mask, ellipsis_mask,
+                                  new_axis_mask, shrink_axis_mask});
+  MS_EXCEPTION_IF_NULL(abs);
+  strided_slice_node->set_abstract(abs);
+  static size_t slice_index = 0;
+  strided_slice_node->set_fullname_with_scope(input_node->fullname_with_scope() + "_strided_slice_" +
+                                              std::to_string(slice_index++));
+  return strided_slice_node;
+}
+
+const CNodePtr ConvertBasePaddings::CreateConcatNode(const FuncGraphPtr &func_graph,
+                                                     const std::vector<AnfNodePtr> &concat_input_vec,
+                                                     const std::string &concat_node_name) const {
+  auto concat_prim = std::make_shared<Primitive>(kConcatOpName);
+  MS_EXCEPTION_IF_NULL(concat_prim);
+  std::vector<int64_t> dyn_input_sizes = {SizeToLong(concat_input_vec.size()), -1};
+  concat_prim->AddAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes));
+
+  AnfNodePtrList inputs = {NewValueNode(concat_prim)};
+  inputs.insert(inputs.end(), concat_input_vec.begin(), concat_input_vec.end());
+  int64_t axis = 0;
+  auto axis_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(axis));
+  inputs.push_back(axis_node);
+  auto concat_node = NewCNode(inputs, func_graph);
+  MS_EXCEPTION_IF_NULL(concat_node);
+
+  std::vector<AnfNodePtr> concat_inputs = concat_input_vec;
+  concat_inputs.push_back(axis_node);
+  auto concat_abs = InferAbstract(concat_prim, concat_inputs);
+  MS_EXCEPTION_IF_NULL(concat_abs);
+  concat_node->set_abstract(concat_abs);
+  concat_node->set_fullname_with_scope(concat_node_name);
+  return concat_node;
+}
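Each StridedSlice built above is a rank-preserving one-element slice of the 1-D paddings tensor: begin = index, end = index + 1, stride = 1, and all five masks zero, so the Concat node can stitch the elements back together in a new order along axis 0. A minimal value-level sketch (illustrative, not MindSpore code) of what one such node computes:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// With all masks zero, begin = i, end = i + 1, strides = 1, the generated
// StridedSlice reduces to a one-element slice of a 1-D tensor that keeps rank 1.
std::vector<int64_t> StridedSlice1D(const std::vector<int64_t> &x, int64_t begin) {
  return std::vector<int64_t>(x.begin() + begin, x.begin() + begin + 1);
}

int main() {
  const std::vector<int64_t> paddings = {1, 2, 3, 4};
  assert((StridedSlice1D(paddings, 2) == std::vector<int64_t>{3}));
  return 0;
}
```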
+const CNodePtr ConvertBasePaddings::ProcessSliceNConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &pad_node,
+                                                        const AnfNodePtr &input_node,
+                                                        const int64_t &padding_dst_length,
+                                                        const int64_t &padding_src_length) const {
+  auto prim = GetCNodePrimitive(pad_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto paddings_contiguous = GetValue<bool>(prim->GetAttr("paddings_contiguous"));
+  std::vector<AnfNodePtr> concat_input_vec;
+
+  // slice and insert to concat in reverse order
+  if (paddings_contiguous) {
+    for (int64_t i = 0; i < padding_src_length; i += static_cast<int64_t>(kSizeTwo)) {
+      auto slice_node_2 = CreateStridedSliceNode(func_graph, input_node, i + kSizeOne);
+      concat_input_vec.insert(concat_input_vec.begin(), slice_node_2);
+
+      auto slice_node_1 = CreateStridedSliceNode(func_graph, input_node, i);
+      concat_input_vec.insert(concat_input_vec.begin(), slice_node_1);
+    }
+  } else {
+    for (int64_t i = 0; i < padding_src_length / 2; ++i) {
+      auto slice_node_2 = CreateStridedSliceNode(func_graph, input_node, i + padding_src_length / 2);
+      concat_input_vec.insert(concat_input_vec.begin(), slice_node_2);
+
+      auto slice_node_1 = CreateStridedSliceNode(func_graph, input_node, i);
+      concat_input_vec.insert(concat_input_vec.begin(), slice_node_1);
+    }
+    prim->AddAttr("paddings_contiguous", MakeValue(true));
+  }
+
+  if (padding_dst_length > padding_src_length) {
+    auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(pad_node, kIndex1);
+    std::shared_ptr<tensor::Tensor> fill_tensor;
+    if (input_paddings_type_id == kNumberTypeInt32) {
+      fill_tensor =
+        std::make_shared<tensor::Tensor>(std::vector<int32_t>(padding_dst_length - padding_src_length, 0), kInt32);
+    } else if (input_paddings_type_id == kNumberTypeInt64) {
+      fill_tensor =
+        std::make_shared<tensor::Tensor>(std::vector<int64_t>(padding_dst_length - padding_src_length, 0), kInt64);
+    } else {
+      MS_LOG_EXCEPTION << "Unsupported data type for PadV3 paddings input.";
+    }
+    MS_EXCEPTION_IF_NULL(fill_tensor);
+    auto fill_node = CreateValueNodeWithKernelInfo(func_graph, fill_tensor);
+    MS_EXCEPTION_IF_NULL(fill_node);
+    concat_input_vec.insert(concat_input_vec.begin(), fill_node);
+  }
+  static size_t concat_index = 0;
+  auto concat_node =
+    CreateConcatNode(func_graph, concat_input_vec,
+                     pad_node->fullname_with_scope() + "_pad_slice_concat" + std::to_string(concat_index++));
+  return concat_node;
+}
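At the value level, ProcessSliceNConcat pairs up the paddings, reverses the pair order, and (via the fill tensor) left-pads with zeros up to 2 * rank(x); both the contiguous and the non-contiguous input layouts normalize to the same result. A standalone simulation of that reordering, assuming the slice/concat semantics shown above:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Value-level effect of ProcessSliceNConcat (illustrative, not MindSpore code):
// in the graph this is one StridedSlice per element feeding a single Concat.
std::vector<int64_t> ReorderPaddings(const std::vector<int64_t> &p, bool contiguous, size_t dst_len) {
  std::vector<int64_t> out;
  const size_t n = p.size();
  if (contiguous) {
    // p = [b0, a0, b1, a1, ...]: keep each (before, after) pair, reverse pair order.
    for (size_t i = 0; i < n; i += 2) {
      out.insert(out.begin(), {p[i], p[i + 1]});
    }
  } else {
    // p = [b0, b1, ..., a0, a1, ...]: zip befores with afters, reverse pair order.
    for (size_t i = 0; i < n / 2; ++i) {
      out.insert(out.begin(), {p[i], p[i + n / 2]});
    }
  }
  out.insert(out.begin(), dst_len - out.size(), 0);  // zero-fill the leading dims
  return out;
}

int main() {
  // 4-D x, paddings given only for the last two dims.
  assert((ReorderPaddings({1, 2, 3, 4}, true, 8) ==
          std::vector<int64_t>{0, 0, 0, 0, 3, 4, 1, 2}));
  // The non-contiguous layout of the same pairs normalizes identically.
  assert((ReorderPaddings({1, 3, 2, 4}, false, 8) ==
          std::vector<int64_t>{0, 0, 0, 0, 3, 4, 1, 2}));
  return 0;
}
```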
+const AnfNodePtr ConvertBasePaddings::CreateDynPaddingsPass(const FuncGraphPtr &graph, const CNodePtr &pad_node,
+                                                            const bool &is_grad) const {
+  // For dyn paddings in PadV3 and PadV3Grad on Ascend, add StridedSlice -> Concat to adjust paddings for ge::PadV3.
+  auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(pad_node, kIndex0);
+  size_t dst_length = input_x_shape.size() * 2;
+  auto prim_name = "PadV3";
+  if (is_grad) {
+    prim_name = "PadV3Grad";
+  }
+
+  auto paddings = common::AnfAlgo::GetInputNode(pad_node, kIndex1);
+  MS_EXCEPTION_IF_NULL(paddings);
+  auto paddings_abstract = paddings->abstract();
+  MS_EXCEPTION_IF_NULL(paddings_abstract);
+  auto paddings_shape_ptr = paddings_abstract->GetShape();
+  MS_EXCEPTION_IF_NULL(paddings_shape_ptr);
+  auto paddings_shape = paddings_shape_ptr->GetShapeVector();
+  (void)CheckAndConvertUtils::CheckInteger("paddings_shape_size", SizeToLong(paddings_shape.size()), kEqual, kDim1,
+                                           prim_name);
+  auto paddings_length = paddings_shape[0];
+  // Not implemented: if is_grad and dst_length < 8, the filled paddings should be expanded to length 8.
+  auto concat_node = ProcessSliceNConcat(graph, pad_node, paddings, dst_length, paddings_length);
+  MS_EXCEPTION_IF_NULL(concat_node);
+  return concat_node;
+}
+
 template <typename T>
 const AnfNodePtr ConvertBasePaddings::OptimizePaddingsValue(const FuncGraphPtr &graph,
                                                             const AbstractBasePtr &ori_paddings,
@@ -127,6 +266,25 @@ const AnfNodePtr ConvertBasePaddings::OptimizePaddingsValue(const FuncGraphPtr &
   return extend_paddings;
 }
 
+const AnfNodePtr ConvertBasePaddings::CreateConstPaddingsNode(const FuncGraphPtr &graph,
+                                                              const CNodePtr &pad_node) const {
+  auto prim = GetCNodePrimitive(pad_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto paddings_contiguous = GetValue<bool>(prim->GetAttr("paddings_contiguous"));
+  // ge::PadV3 only supports `paddings` whose length is exactly twice the rank of `x`
+  auto input_paddings = common::AnfAlgo::GetInputNode(pad_node, kIndex1);
+  MS_EXCEPTION_IF_NULL(input_paddings);
+  auto paddings_abstract = input_paddings->abstract();
+  MS_EXCEPTION_IF_NULL(paddings_abstract);
+
+  auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(pad_node, kIndex0);
+  auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(pad_node, kIndex1);
+  auto paddings_value_node = CreateConstPaddingsPass(graph, paddings_abstract, paddings_contiguous,
+                                                     input_x_shape.size() * 2, input_paddings_type_id);
+  MS_EXCEPTION_IF_NULL(paddings_value_node);
+  return paddings_value_node;
+}
+
 const AnfNodePtr ConvertBasePaddings::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(graph);
@@ -135,30 +293,24 @@ const AnfNodePtr ConvertBasePaddings::Process(const FuncGraphPtr &graph, const A
   MS_EXCEPTION_IF_NULL(cnode);
 
   auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex0);
-  auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, kIndex1);
-  auto opt_paddings_size = 2 * input_x_shape.size();
+  auto padding_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex1);
+  if (IsDynamicRank(input_x_shape) || IsDynamic(padding_shape)) {
+    MS_LOG_EXCEPTION << "The rank of input x or the shape of input paddings is dynamic, which is not supported.";
+  }
   if (HasDynPaddings(cnode)) {
-    MS_EXCEPTION(TypeError) << "While running in Ascend, the input [paddings] of PadV3 is required to be constant, but "
-                               "that is dynamic in node["
-                            << node->fullname_with_scope() << "]";
+    auto concat_node = CreateDynPaddingsNode(graph, cnode);
+    MS_EXCEPTION_IF_NULL(concat_node);
+    auto node_prim = GetCNodePrimitive(node);
+    MS_EXCEPTION_IF_NULL(node_prim);
+    node_prim->AddAttr("is_dyn_paddings", MakeValue(true));
+    cnode->set_input(kIndex2, concat_node);
   } else {
-    auto prim = GetCNodePrimitive(cnode);
-    MS_EXCEPTION_IF_NULL(prim);
-    auto paddings_contiguous = GetValue<bool>(prim->GetAttr("paddings_contiguous"));
-    // ge::padV3 only support that the length of `paddings` is twice than the rank of `x`
-    auto input_paddings = common::AnfAlgo::GetInputNode(cnode, kIndex1);
-    MS_EXCEPTION_IF_NULL(input_paddings);
-    auto paddings_abstract = input_paddings->abstract();
-    MS_EXCEPTION_IF_NULL(paddings_abstract);
-    auto paddings_type = paddings_abstract->GetType();
-    MS_EXCEPTION_IF_NULL(paddings_type);
-
-    auto paddings_value_node =
-      CreatePaddingsNode(graph, paddings_abstract, paddings_contiguous, opt_paddings_size, input_paddings_type_id);
+    auto paddings_value_node = CreateConstPaddingsNode(graph, cnode);
     MS_EXCEPTION_IF_NULL(paddings_value_node);
     cnode->set_input(kIndex2, paddings_value_node);
   }
+  // Not verified: for PadV3Grad, if the rank of the input tensor is less than 4, the input should be expanded to 4.
   auto is_expand = ExpandInputXDims(graph, cnode);
   if (is_expand) {
     ReduceOutputDims(graph, cnode);
@@ -166,6 +318,18 @@ const AnfNodePtr ConvertBasePaddings::Process(const FuncGraphPtr &graph, const A
   return node;
 }
 
+const AnfNodePtr ConvertPadV3GradPaddings::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                   const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (HasDynPaddings(cnode)) {
+    MS_EXCEPTION(RuntimeError) << "PadV3Grad doesn't support dynamic paddings input.";
+  }
+  return ConvertBasePaddings::Process(graph, node, equiv);
+}
+
 bool ConvertPadV3GradPaddings::ExpandInputXDims(const FuncGraphPtr &graph, const CNodePtr &node) const {
   auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex0);
   auto input_x_rank = input_x_shape.size();
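The rewritten Process above now takes one of three routes. A minimal mirror of that control flow (illustrative names, not the real pass):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Fully dynamic shapes are rejected up front; a paddings value known at compile
// time folds into a constant node; a truly dynamic paddings value is rewired to
// the StridedSlice+Concat subgraph and the primitive is tagged with
// "is_dyn_paddings" so PadV3's infer step knows the pairs arrive pre-reordered.
struct RewriteResult {
  std::string paddings_input;  // what ends up wired into input 2 of PadV3
  bool is_dyn_paddings_attr;   // whether the primitive gets tagged
};

RewriteResult RewritePaddings(bool dyn_rank_or_shape, bool paddings_value_dynamic) {
  if (dyn_rank_or_shape) {
    throw std::runtime_error("dynamic rank / dynamic paddings shape is not supported");
  }
  if (paddings_value_dynamic) {
    return {"Concat(zero_fill, slices...)", true};
  }
  return {"ValueNode(folded paddings)", false};
}

int main() {
  assert(RewritePaddings(false, true).is_dyn_paddings_attr);
  assert(RewritePaddings(false, false).paddings_input == "ValueNode(folded paddings)");
  return 0;
}
```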
diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.h
index 788bbf8d33a..cfed92c0721 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.h
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.h
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_CONVERT_PAD_V3_PADDINGS_H_
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_CONVERT_PAD_V3_PADDINGS_H_
 
+#include <string>
 #include <vector>
 #include "include/backend/optimizer/optimizer.h"
@@ -31,13 +32,25 @@ class ConvertBasePaddings : public PatternProcessPass {
   bool HasDynPaddings(const CNodePtr &) const;
 
   const CNodePtr CreateReshapeNode(const FuncGraphPtr &, const AnfNodePtr &, const ShapeVector &) const;
+  const CNodePtr CreateStridedSliceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
+                                        int64_t index) const;
+  const CNodePtr CreateConcatNode(const FuncGraphPtr &, const std::vector<AnfNodePtr> &, const std::string &) const;
+  const CNodePtr ProcessSliceNConcat(const FuncGraphPtr &, const AnfNodePtr &, const AnfNodePtr &, const int64_t &,
+                                     const int64_t &) const;
+
+  const AnfNodePtr CreateDynPaddingsPass(const FuncGraphPtr &, const CNodePtr &, const bool &) const;
+  virtual const AnfNodePtr CreateDynPaddingsNode(const FuncGraphPtr &, const CNodePtr &) const { return nullptr; }
+
   template <typename T>
   const AnfNodePtr OptimizePaddingsValue(const FuncGraphPtr &, const AbstractBasePtr &, const bool &, const size_t &,
                                          bool force_length8) const;
-  virtual const AnfNodePtr CreatePaddingsNode(const FuncGraphPtr &, const AbstractBasePtr &, const bool &,
-                                              const size_t &, const TypeId &) const {
+  virtual const AnfNodePtr CreateConstPaddingsPass(const FuncGraphPtr &, const AbstractBasePtr &, const bool &,
+                                                   const size_t &, const TypeId &) const {
     return nullptr;
   }
+  const AnfNodePtr CreateConstPaddingsNode(const FuncGraphPtr &, const CNodePtr &) const;
+
+ private:
   virtual bool ExpandInputXDims(const FuncGraphPtr &, const CNodePtr &) const { return false; }
   virtual void ReduceOutputDims(const FuncGraphPtr &, const CNodePtr &) const {}
 };
@@ -50,15 +63,21 @@ class ConvertPadV3Paddings : public ConvertBasePaddings {
   const BaseRef DefinePattern() const override;
 
  private:
-  const AnfNodePtr CreatePaddingsNode(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
-                                      const bool &paddings_contiguous, const size_t &dst_length,
-                                      const TypeId &type_id) const override {
+  const AnfNodePtr CreateConstPaddingsPass(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
+                                           const bool &paddings_contiguous, const size_t &dst_length,
+                                           const TypeId &type_id) const override {
     if (type_id == kNumberTypeInt32) {
       return ConvertBasePaddings::OptimizePaddingsValue<int32_t>(
         graph, ori_paddings, paddings_contiguous, dst_length, false);
+    } else if (type_id == kNumberTypeInt64) {
+      return ConvertBasePaddings::OptimizePaddingsValue<int64_t>(
+        graph, ori_paddings, paddings_contiguous, dst_length, false);
+    } else {
+      MS_LOG_EXCEPTION << "Unsupported data type for PadV3 paddings input.";
     }
-    return ConvertBasePaddings::OptimizePaddingsValue<int64_t>(
-      graph, ori_paddings, paddings_contiguous, dst_length, false);
+  }
+  const AnfNodePtr CreateDynPaddingsNode(const FuncGraphPtr &graph, const CNodePtr &pad_node) const override {
+    return ConvertBasePaddings::CreateDynPaddingsPass(graph, pad_node, false);
   }
   bool ExpandInputXDims(const FuncGraphPtr &, const CNodePtr &) const override { return false; }
   void ReduceOutputDims(const FuncGraphPtr &, const CNodePtr &) const override {}
@@ -70,17 +89,24 @@ class ConvertPadV3GradPaddings : public ConvertBasePaddings {
     : ConvertBasePaddings("convert_pad_v3_grad_paddings", multi_graph) {}
   ~ConvertPadV3GradPaddings() override = default;
   const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
-  const AnfNodePtr CreatePaddingsNode(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
-                                      const bool &paddings_contiguous, const size_t &dst_length,
-                                      const TypeId &type_id) const override {
+  const AnfNodePtr CreateConstPaddingsPass(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
+                                           const bool &paddings_contiguous, const size_t &dst_length,
+                                           const TypeId &type_id) const override {
    if (type_id == kNumberTypeInt32) {
      return ConvertBasePaddings::OptimizePaddingsValue<int32_t>(
        graph, ori_paddings, paddings_contiguous, dst_length, true);
+    } else if (type_id == kNumberTypeInt64) {
+      return ConvertBasePaddings::OptimizePaddingsValue<int64_t>(
+        graph, ori_paddings, paddings_contiguous, dst_length, true);
+    } else {
+      MS_LOG_EXCEPTION << "Unsupported data type for PadV3Grad paddings input.";
    }
-    return ConvertBasePaddings::OptimizePaddingsValue<int64_t>(graph, ori_paddings,
-                                                               paddings_contiguous, dst_length, true);
+  }
+  const AnfNodePtr CreateDynPaddingsNode(const FuncGraphPtr &graph, const CNodePtr &pad_node) const override {
+    return ConvertBasePaddings::CreateDynPaddingsPass(graph, pad_node, true);
   }
   bool ExpandInputXDims(const FuncGraphPtr &, const CNodePtr &) const override;
   void ReduceOutputDims(const FuncGraphPtr &, const CNodePtr &) const override;
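The header wires this up as a template-method design: the base pass owns the shared pipeline, and each subclass only picks the paddings element type and the force_length8 flag (true only for PadV3Grad here). A simplified skeleton of that dispatch, with stand-in names rather than the real pass classes:

```cpp
#include <cstdint>
#include <cstdio>

// Base pass owns the pipeline; subclasses choose element type and force_length8.
struct BasePaddingsPass {
  virtual ~BasePaddingsPass() = default;
  void Run(bool is_int32) const { CreateConstPaddings(is_int32); }
  virtual void CreateConstPaddings(bool is_int32) const = 0;

 protected:
  template <typename T>
  void OptimizePaddingsValue(bool force_length8) const {
    std::printf("optimize as %zu-byte ints, force_length8=%d\n", sizeof(T), force_length8);
  }
};

struct PadV3Pass : BasePaddingsPass {
  void CreateConstPaddings(bool is_int32) const override {
    is_int32 ? OptimizePaddingsValue<int32_t>(false) : OptimizePaddingsValue<int64_t>(false);
  }
};

struct PadV3GradPass : BasePaddingsPass {
  void CreateConstPaddings(bool is_int32) const override {
    is_int32 ? OptimizePaddingsValue<int32_t>(true) : OptimizePaddingsValue<int64_t>(true);
  }
};

int main() {
  PadV3Pass().Run(true);
  PadV3GradPass().Run(false);
  return 0;
}
```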
diff --git a/mindspore/core/ops/pad_v3.cc b/mindspore/core/ops/pad_v3.cc
index 8f522dde0ed..41b885e8180 100644
--- a/mindspore/core/ops/pad_v3.cc
+++ b/mindspore/core/ops/pad_v3.cc
@@ -60,6 +60,17 @@ void PaddingsSizeCheck(const PrimitivePtr &primitive, const int64_t paddings_siz
   constexpr int64_t nFour = 4;
   constexpr int64_t nFive = 5;
   auto prim_name = primitive->name();
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
+    auto is_dyn_paddings = primitive->GetAttr("is_dyn_paddings");
+    if (is_dyn_paddings != nullptr && GetValue<bool>(is_dyn_paddings)) {
+      if (paddings_size / nTwo != size) {
+        MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the paddings length must be equal to " << size * nTwo;
+      }
+      return;
+    }
+  }
   auto mode = GetValue<std::string>(primitive->GetAttr("mode"));
   if (mode == kConstant) {
     if (paddings_size / nTwo > size) {
@@ -162,6 +173,23 @@ void CheckAscendInputXDim(const size_t &x_dim, const std::string &prim_name) {
   }
 }
 
+void AscendTransformPaddingsAttr(const PrimitivePtr &primitive,
+                                 std::vector<std::pair<int64_t, int64_t>> *ori_paddings_attr) {
+  // If the `paddings` comes from the node added by the pass, it has two features:
+  // 1. the length of `paddings` is twice the rank of `x`;
+  // 2. the `paddings` pairs map to the dims of `x` from lowest to lowest (the pair at index 0 pads dim 0),
+  //    while on the other backends the mapping is from lowest to highest (the pair at index 0 pads the last dim).
+  // So the transform should be applied only when the `paddings` comes from the node added by the pass.
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
+    auto is_dyn_paddings = primitive->GetAttr("is_dyn_paddings");
+    if (is_dyn_paddings != nullptr && GetValue<bool>(is_dyn_paddings)) {
+      std::reverse(ori_paddings_attr->begin(), ori_paddings_attr->end());
+    }
+  }
+}
+
 abstract::ShapePtr PadV3InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   constexpr int64_t kEdgeMaxDims = 5;
   constexpr int64_t kOtherMinDims = 3;
@@ -242,6 +270,7 @@ abstract::ShapePtr PadV3InferShape(const PrimitivePtr &primitive, const std::vec
         std::make_pair(paddings_val[LongToSize(nTwo * i)], paddings_val[LongToSize(nTwo * i + 1)]));
     }
   }
+  AscendTransformPaddingsAttr(primitive, &paddings_attr);
   std::vector<int64_t> out_shape;
   for (int64_t i = 0; i < size; ++i) {
     int64_t now_dim_size = x_shape[LongToSize(i)] + paddings_attr[LongToSize(size - i - 1)].first +
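AscendTransformPaddingsAttr compensates for the pair order produced by the backend pass: those pairs arrive first-dim-first, while PadV3InferShape indexes paddings_attr back to front (size - i - 1). A small simulation showing that reversing the pass-emitted pairs yields the same output shape (illustrative, not MindSpore code):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Default backends supply pairs last-dim-first; the pass emits them
// first-dim-first, so the infer step reverses them before reusing the
// common "size - i - 1" indexing.
std::vector<int64_t> InferShape(const std::vector<int64_t> &x,
                                std::vector<std::pair<int64_t, int64_t>> pads,
                                bool pass_emitted_order) {
  if (pass_emitted_order) std::reverse(pads.begin(), pads.end());
  std::vector<int64_t> out;
  const int64_t size = static_cast<int64_t>(x.size());
  for (int64_t i = 0; i < size; ++i) {
    out.push_back(x[i] + pads[size - i - 1].first + pads[size - i - 1].second);
  }
  return out;
}

int main() {
  // x: (2, 3); pad the last dim by (1, 2) and the first dim by (0, 0).
  assert((InferShape({2, 3}, {{1, 2}, {0, 0}}, false) == std::vector<int64_t>{2, 6}));
  // Same pads in the pass-emitted first-dim-first order give the same shape.
  assert((InferShape({2, 3}, {{0, 0}, {1, 2}}, true) == std::vector<int64_t>{2, 6}));
  return 0;
}
```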
diff --git a/tests/st/ops/ascend/test_pad_v3.py b/tests/st/ops/ascend/test_pad_v3.py
index 337ec513bb5..7c7a8ef632c 100644
--- a/tests/st/ops/ascend/test_pad_v3.py
+++ b/tests/st/ops/ascend/test_pad_v3.py
@@ -184,7 +184,8 @@ def test_padv3_constant_shape_4d(x_data_type, mode, ms_mode):
 @pytest.mark.parametrize('x_data_type', [np.int16, np.float32])
 @pytest.mark.parametrize('mode', ["constant", "edge"])
 @pytest.mark.parametrize('ms_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
-def test_padv3_constant_shape_5d(x_data_type, mode, ms_mode):
+@pytest.mark.parametrize('is_dyn_paddings', [True, False])
+def test_padv3_constant_shape_5d(x_data_type, mode, ms_mode, is_dyn_paddings):
     """
     Feature: test padv3 x and const shape paddings
     Description: test padv3 with const shape paddings
@@ -193,6 +194,8 @@ def test_padv3_constant_shape_5d(x_data_type, mode, ms_mode):
     context.set_context(mode=ms_mode, device_target="Ascend")
     x = Tensor(np.arange(18).reshape(1, 1, 2, 3, 3).astype(x_data_type))
     paddings = (1, 2, 1, 1, 0, 1)
+    if is_dyn_paddings:
+        paddings = Tensor(paddings, dtype=ms.int64)
     value = None
     if mode == "constant":
         value = 99 if x_data_type == np.int16 else 99.0