!43523 fix Reshape dynamic shape infer shape and transpose_eliminate

Merge pull request !43523 from looop5/fix_reshape
This commit is contained in:
i-robot 2022-10-11 01:23:13 +00:00 committed by Gitee
commit cedd215b56
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 11 additions and 4 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,6 +42,9 @@ class TransposeSameIOEliminater : public AnfVisitor {
}
auto value = GetValueNode(tuple_);
if (value == nullptr || !value->isa<ValueSequence>()) {
return nullptr;
}
auto elements = GetValue<std::vector<int64_t>>(value);
if (elements.empty()) {
return nullptr;

View File

@ -118,6 +118,9 @@ class ReshapeInfer : public abstract::OpInferBase {
MS_EXCEPTION_IF_CHECK_FAIL(LongToSize(shape_value[0]) == shape_vector.size(),
"Illegal shape of shape value");
output_shape = shape_vector;
if (std::count_if(output_shape.begin(), output_shape.end(), [](int64_t s) { return s < 0; }) > 1) {
return std::make_shared<abstract::Shape>(output_shape);
}
}
}
}

View File

@ -31,7 +31,6 @@ class NetCdist(nn.Cell):
return self.cdist(x1, x2)
@pytest.mark.skip(reason="Error GetValue for value")
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu

View File

@ -32,10 +32,11 @@ class NetReshape(nn.Cell):
return self.reshape(x, self.target)
@pytest.mark.skip(reason="At most one component of input shape can be -1, but got [-1, -1]")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_reshape_shape():
"""
@ -49,10 +50,11 @@ def test_dynamic_reshape_shape():
test_dynamic.test_dynamic_grad_net(x)
@pytest.mark.skip(reason="At most one component of input shape can be -1, but got [-1, -1]")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_reshape_rank():
"""