forked from mindspore-Ecosystem/mindspore
use original axis attr
This commit is contained in:
parent 232aede207
commit b86e72e0f4
@@ -167,18 +167,15 @@ const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const A
  if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) {
    return nullptr;
  }
  NormalizeReduceAttrAxis(node->cast<CNodePtr>());
  auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
  if (convert_map == kReduceConvertMap.end()) {
    if (common::AnfAlgo::IsDynamicShape(node)) {
      DynamicAttrUpdate(node);
    }
    return nullptr;
  }
  convert_map->second(node->cast<CNodePtr>());
  if (common::AnfAlgo::IsDynamicShape(node)) {
    DynamicAttrUpdate(node);
  }
  auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
  if (convert_map == kReduceConvertMap.end()) {
    return nullptr;
  }
  NormalizeReduceAttrAxis(node->cast<CNodePtr>());
  convert_map->second(node->cast<CNodePtr>());
  return nullptr;
}
}  // namespace mindspore::opt
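
In the hunk above, the converter is looked up in kReduceConvertMap by the node's input format, and the axis attribute is only normalized and rewritten once that lookup succeeds, which keeps the original axis attribute on formats that have no registered converter. The snippet below is a minimal, self-contained sketch of that lookup-then-convert pattern; the map, names, and toy conversion are illustrative assumptions, not MindSpore APIs.

    // Minimal sketch (not MindSpore code): look up a converter by input format,
    // and only touch the axis value when the lookup succeeds.
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    using AxisConverter = std::function<void(int64_t &)>;

    void ProcessAxis(const std::string &format, int64_t &axis) {
      static const std::unordered_map<std::string, AxisConverter> converters = {
          {"SPECIAL_FORMAT", [](int64_t &a) { a += 1; }},  // toy conversion for illustration
      };
      auto it = converters.find(format);
      if (it == converters.end()) {
        return;  // no converter registered: keep the original axis value
      }
      it->second(axis);  // rewrite the axis only after a successful lookup
    }

    int main() {
      int64_t axis = 1;
      ProcessAxis("DEFAULT", axis);         // no converter: axis stays 1
      std::cout << axis << std::endl;
      ProcessAxis("SPECIAL_FORMAT", axis);  // converter found: axis becomes 2
      std::cout << axis << std::endl;
      return 0;
    }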

@@ -103,6 +103,7 @@ AnfNodePtr AscendClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func
  std::vector<int64_t> axis;
  if (axis_value->isa<ValueSequence>()) {
    axis = GetValue<std::vector<int64_t>>(axis_value);
    common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
    if (axis.empty()) {  // reduce_sum for all dimensions
      for (size_t i = 0; i < dim; ++i) {
        (void)axis.emplace_back(i);

@@ -110,11 +111,11 @@ AnfNodePtr AscendClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func
    }
  } else if (axis_value->isa<Int64Imm>()) {
    (void)axis.emplace_back(GetValue<int64_t>(axis_value));
    common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
  } else {
    MS_EXCEPTION(TypeError) << "For `" << prim::kPrimClipByNorm->name()
                            << "`, the type of attribute `axis` is invalid.";
  }
  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), reduce_sum);
  // Set abstract to `reduce_sum` op
  int64_t ddim = SizeToLong(dim);
  ShapeVector reduce_sum_output_shape = shape_vec;
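
The hunks above read the axis either as a value sequence or as a single integer, record it on the ReduceSum node, and then expand an empty axis list to cover every dimension. Below is a small standalone sketch of that empty-axis convention with hypothetical names; it is not the AscendClipByNormFission implementation.

    // Minimal sketch (not MindSpore code): an empty reduce axis list is treated
    // as "reduce over all dimensions" and expanded to 0..dim-1.
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<int64_t> ExpandEmptyAxis(std::vector<int64_t> axis, size_t dim) {
      if (axis.empty()) {
        for (size_t i = 0; i < dim; ++i) {
          axis.emplace_back(static_cast<int64_t>(i));
        }
      }
      return axis;
    }

    int main() {
      for (int64_t a : ExpandEmptyAxis({}, 3)) {
        std::cout << a << " ";  // prints: 0 1 2
      }
      std::cout << std::endl;
      return 0;
    }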

@@ -1,5 +1,5 @@
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -30,10 +30,6 @@ namespace {
constexpr size_t kCNodePrimitiveIdx = 0;
constexpr size_t kAllToAllInputIdx = 1;

inline int64_t NormalizeDim(const ShapeVector &shape, int64_t dim) {
  return dim < 0 ? SizeToLong(shape.size()) + dim : dim;
}

void ChangePrimitiveToAllToAllV(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto neighbor_exchange = node->cast<CNodePtr>();

@@ -75,8 +71,8 @@ CNodePtr AllToAllUnifyMindIR::CreateSplitNode(const FuncGraphPtr &graph, const C
  MS_EXCEPTION_IF_NULL(split_v);
  auto dtype = common::AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
  auto shape = common::AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
  split_dim = NormalizeDim(shape, split_dim);
  if (SizeToLong(shape.size()) <= split_dim) {
  auto shape_size = SizeToLong(shape.size());
  if (split_dim >= shape_size || split_dim < -shape_size) {
    MS_LOG(EXCEPTION) << "Invalid split dim " << split_dim << " is over the shape size " << shape.size()
                      << trace::DumpSourceLines(all_to_all);
  }

@@ -163,8 +159,8 @@ CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph, const 
  auto concat = NewCNode(concat_input, graph);
  MS_EXCEPTION_IF_NULL(concat);
  auto single_shape = common::AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
  concat_dim = NormalizeDim(single_shape, concat_dim);
  if (LongToSize(concat_dim) >= single_shape.size()) {
  auto shape_size = SizeToLong(single_shape.size());
  if (concat_dim >= shape_size || concat_dim < -shape_size) {
    MS_LOG(EXCEPTION) << "Invalid concat dim " << concat_dim << " is greater than shape size " << single_shape.size()
                      << trace::DumpSourceLines(all_to_all);
  }
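
The two hunks above drop the NormalizeDim helper and validate split_dim/concat_dim directly against the range [-shape_size, shape_size) instead of normalizing the value before an upper-bound-only check. The sketch below illustrates that range check with hypothetical names; it is not the AllToAllUnifyMindIR code.

    // Minimal sketch (not MindSpore code): validate a dimension index that may be
    // negative, accepting values in [-rank, rank).
    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <vector>

    void CheckDim(int64_t dim, const std::vector<int64_t> &shape) {
      auto rank = static_cast<int64_t>(shape.size());
      if (dim >= rank || dim < -rank) {
        throw std::out_of_range("dim " + std::to_string(dim) + " is out of range for rank " + std::to_string(rank));
      }
    }

    int main() {
      std::vector<int64_t> shape = {2, 3, 4};
      CheckDim(-1, shape);  // ok: -1 refers to the last dimension
      try {
        CheckDim(3, shape);  // throws: valid dims for rank 3 are -3..2
      } catch (const std::out_of_range &e) {
        std::cout << e.what() << std::endl;
      }
      return 0;
    }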

@@ -154,3 +154,20 @@ def test_dynamic_axis_reduce(data_type):
    dyn_axis_case(data_type)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    dyn_axis_case(data_type)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("data_type", [np.float32])
def test_dynamic_axis_reduce_ascend(data_type):
    """
    Feature: Reduce DynamicShape.
    Description: Test case of dynamic shape for reduce operator with dynamic axis in Ascend.
    Expectation: success.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    dyn_axis_case(data_type)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    dyn_axis_case(data_type)