From b86e72e0f491bcf9309ae0e800b57e0fe37901b7 Mon Sep 17 00:00:00 2001
From: yuchaojie
Date: Mon, 6 Mar 2023 17:26:15 +0800
Subject: [PATCH] use original axis attr

---
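Notes (review aid, not part of the commit message): the first hunk defers the
axis rewrite (NormalizeReduceAttrAxis and the format-specific conversion) until
a matching entry is found in kReduceConvertMap, so DynamicAttrUpdate now runs
against the original axis attribute; the ClipByNorm fission hunks record
kAttrAxis before an empty axis list is expanded to all dimensions; and the
AllToAll hunks drop the NormalizeDim helper in favor of range-checking possibly
negative dims directly. A minimal standalone sketch of that bounds rule
(hypothetical IsValidDim helper, not MindSpore code): a dim is valid iff
-rank <= dim < rank.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors the patch's `dim >= shape_size || dim < -shape_size` rejection
    // test: negative dims down to -rank address axes from the end.
    bool IsValidDim(const std::vector<int64_t> &shape, int64_t dim) {
      auto rank = static_cast<int64_t>(shape.size());
      return dim >= -rank && dim < rank;
    }

    int main() {
      std::vector<int64_t> shape = {8, 16, 32};
      std::cout << IsValidDim(shape, -1) << '\n';  // 1: last axis is in range
      std::cout << IsValidDim(shape, 3) << '\n';   // 0: out of range, the pass raises
      return 0;
    }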
 .../format_type/change_axis_of_reduce_kernel.cc    | 15 ++++++---------
 .../ir_fission/ascend_clip_by_norm_fission.cc      |  3 ++-
 .../optimizer/mindir/all_to_all_unify_mindir.cc    | 14 +++++---------
 .../dynamic_shape/test_dynamic_reduce_ops.py       | 17 +++++++++++++++++
 4 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/change_axis_of_reduce_kernel.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/change_axis_of_reduce_kernel.cc
index 4470b6d2962..52a677bd1e2 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/change_axis_of_reduce_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/change_axis_of_reduce_kernel.cc
@@ -167,18 +167,15 @@ const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const A
   if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) {
     return nullptr;
   }
-  NormalizeReduceAttrAxis(node->cast<CNodePtr>());
-  auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
-  if (convert_map == kReduceConvertMap.end()) {
-    if (common::AnfAlgo::IsDynamicShape(node)) {
-      DynamicAttrUpdate(node);
-    }
-    return nullptr;
-  }
-  convert_map->second(node->cast<CNodePtr>());
   if (common::AnfAlgo::IsDynamicShape(node)) {
     DynamicAttrUpdate(node);
   }
+  auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
+  if (convert_map == kReduceConvertMap.end()) {
+    return nullptr;
+  }
+  NormalizeReduceAttrAxis(node->cast<CNodePtr>());
+  convert_map->second(node->cast<CNodePtr>());
   return nullptr;
 }
 }  // namespace mindspore::opt
diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/ascend_clip_by_norm_fission.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/ascend_clip_by_norm_fission.cc
index 89142242e94..9c53b7542f4 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/ascend_clip_by_norm_fission.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/ascend_clip_by_norm_fission.cc
@@ -103,6 +103,7 @@ AnfNodePtr AscendClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func
   std::vector<int64_t> axis;
   if (axis_value->isa<ValueSequence>()) {
     axis = GetValue<std::vector<int64_t>>(axis_value);
+    common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
     if (axis.empty()) {  // reduce_sum for all dimensions
       for (size_t i = 0; i < dim; ++i) {
         (void)axis.emplace_back(i);
@@ -110,11 +111,11 @@
     }
   } else if (axis_value->isa<Int64Imm>()) {
     (void)axis.emplace_back(GetValue<int64_t>(axis_value));
+    common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
   } else {
     MS_EXCEPTION(TypeError) << "For `" << prim::kPrimClipByNorm->name()
                             << "`, the type of attribute `axis` is invalid.";
   }
-  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
   // Set abstract to `reduce_sum` op
   int64_t ddim = SizeToLong(dim);
   ShapeVector reduce_sum_output_shape = shape_vec;
diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc
index 3b78be462cc..4702906905d 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,10 +30,6 @@ namespace {
 constexpr size_t kCNodePrimitiveIdx = 0;
 constexpr size_t kAllToAllInputIdx = 1;
 
-inline int64_t NormalizeDim(const ShapeVector &shape, int64_t dim) {
-  return dim < 0 ? SizeToLong(shape.size()) + dim : dim;
-}
-
 void ChangePrimitiveToAllToAllV(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto neighbor_exchange = node->cast<CNodePtr>();
@@ -75,8 +71,8 @@ CNodePtr AllToAllUnifyMindIR::CreateSplitNode(const FuncGraphPtr &graph, const C
   MS_EXCEPTION_IF_NULL(split_v);
   auto dtype = common::AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
   auto shape = common::AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
-  split_dim = NormalizeDim(shape, split_dim);
-  if (SizeToLong(shape.size()) <= split_dim) {
+  auto shape_size = SizeToLong(shape.size());
+  if (split_dim >= shape_size || split_dim < -shape_size) {
     MS_LOG(EXCEPTION) << "Invalid split dim " << split_dim << " is over the shape size " << shape.size()
                       << trace::DumpSourceLines(all_to_all);
   }
@@ -163,8 +159,8 @@ CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph, const
   auto concat = NewCNode(concat_input, graph);
   MS_EXCEPTION_IF_NULL(concat);
   auto single_shape = common::AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
-  concat_dim = NormalizeDim(single_shape, concat_dim);
-  if (LongToSize(concat_dim) >= single_shape.size()) {
+  auto shape_size = SizeToLong(single_shape.size());
+  if (concat_dim >= shape_size || concat_dim < -shape_size) {
     MS_LOG(EXCEPTION) << "Invalid concat dim " << concat_dim << " is greater than shape size " << single_shape.size()
                       << trace::DumpSourceLines(all_to_all);
   }
diff --git a/tests/st/ops/dynamic_shape/test_dynamic_reduce_ops.py b/tests/st/ops/dynamic_shape/test_dynamic_reduce_ops.py
index 7f1001b5a22..de91bcc9a84 100644
--- a/tests/st/ops/dynamic_shape/test_dynamic_reduce_ops.py
+++ b/tests/st/ops/dynamic_shape/test_dynamic_reduce_ops.py
@@ -154,3 +154,20 @@ def test_dynamic_axis_reduce(data_type):
     dyn_axis_case(data_type)
     context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
     dyn_axis_case(data_type)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("data_type", [np.float32])
+def test_dynamic_axis_reduce_ascend(data_type):
+    """
+    Feature: Reduce DynamicShape.
+    Description: Test case of dynamic shape for reduce operator with dynamic axis in Ascend.
+    Expectation: success.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    dyn_axis_case(data_type)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    dyn_axis_case(data_type)