use original axis attr

This commit is contained in:
yuchaojie 2023-03-06 17:26:15 +08:00
parent 232aede207
commit b86e72e0f4
4 changed files with 30 additions and 19 deletions

View File

@ -167,18 +167,15 @@ const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const A
if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) {
return nullptr;
}
NormalizeReduceAttrAxis(node->cast<CNodePtr>());
auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
if (convert_map == kReduceConvertMap.end()) {
if (common::AnfAlgo::IsDynamicShape(node)) {
DynamicAttrUpdate(node);
}
return nullptr;
}
convert_map->second(node->cast<CNodePtr>());
if (common::AnfAlgo::IsDynamicShape(node)) {
DynamicAttrUpdate(node);
}
auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
if (convert_map == kReduceConvertMap.end()) {
return nullptr;
}
NormalizeReduceAttrAxis(node->cast<CNodePtr>());
convert_map->second(node->cast<CNodePtr>());
return nullptr;
}
} // namespace mindspore::opt

View File

@ -103,6 +103,7 @@ AnfNodePtr AscendClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func
std::vector<int64_t> axis;
if (axis_value->isa<ValueSequence>()) {
axis = GetValue<std::vector<int64_t>>(axis_value);
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
if (axis.empty()) { // reduce_sum for all dimensions
for (size_t i = 0; i < dim; ++i) {
(void)axis.emplace_back(i);
@ -110,11 +111,11 @@ AnfNodePtr AscendClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func
}
} else if (axis_value->isa<Int64Imm>()) {
(void)axis.emplace_back(GetValue<int64_t>(axis_value));
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
} else {
MS_EXCEPTION(TypeError) << "For `" << prim::kPrimClipByNorm->name()
<< "`, the type of attribute `axis` is invalid.";
}
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
// Set abstract to `reduce_sum` op
int64_t ddim = SizeToLong(dim);
ShapeVector reduce_sum_output_shape = shape_vec;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
* Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,10 +30,6 @@ namespace {
constexpr size_t kCNodePrimitiveIdx = 0;
constexpr size_t kAllToAllInputIdx = 1;
inline int64_t NormalizeDim(const ShapeVector &shape, int64_t dim) {
return dim < 0 ? SizeToLong(shape.size()) + dim : dim;
}
void ChangePrimitiveToAllToAllV(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto neighbor_exchange = node->cast<CNodePtr>();
@ -75,8 +71,8 @@ CNodePtr AllToAllUnifyMindIR::CreateSplitNode(const FuncGraphPtr &graph, const C
MS_EXCEPTION_IF_NULL(split_v);
auto dtype = common::AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
auto shape = common::AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
split_dim = NormalizeDim(shape, split_dim);
if (SizeToLong(shape.size()) <= split_dim) {
auto shape_size = SizeToLong(shape.size());
if (split_dim >= shape_size || split_dim < -shape_size) {
MS_LOG(EXCEPTION) << "Invalid split dim " << split_dim << " is over the shape size " << shape.size()
<< trace::DumpSourceLines(all_to_all);
}
@ -163,8 +159,8 @@ CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph, const
auto concat = NewCNode(concat_input, graph);
MS_EXCEPTION_IF_NULL(concat);
auto single_shape = common::AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
concat_dim = NormalizeDim(single_shape, concat_dim);
if (LongToSize(concat_dim) >= single_shape.size()) {
auto shape_size = SizeToLong(single_shape.size());
if (concat_dim >= shape_size || concat_dim < -shape_size) {
MS_LOG(EXCEPTION) << "Invalid concat dim " << concat_dim << " is greater than shape size " << single_shape.size()
<< trace::DumpSourceLines(all_to_all);
}

View File

@ -154,3 +154,20 @@ def test_dynamic_axis_reduce(data_type):
dyn_axis_case(data_type)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
dyn_axis_case(data_type)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("data_type", [np.float32])
def test_dynamic_axis_reduce_ascend(data_type):
"""
Feature: Reduce DynamicShape.
Description: Test case of dynamic shape for reduce operator with dynamic axis in Ascend.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
dyn_axis_case(data_type)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
dyn_axis_case(data_type)