forked from mindspore-Ecosystem/mindspore
!43523 fix Reshape dynamic shape infer shape and transpose_eliminate
Merge pull request !43523 from looop5/fix_reshape
This commit is contained in:
commit
cedd215b56
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,6 +42,9 @@ class TransposeSameIOEliminater : public AnfVisitor {
|
|||
}
|
||||
|
||||
auto value = GetValueNode(tuple_);
|
||||
if (value == nullptr || !value->isa<ValueSequence>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto elements = GetValue<std::vector<int64_t>>(value);
|
||||
if (elements.empty()) {
|
||||
return nullptr;
|
||||
|
|
|
@ -118,6 +118,9 @@ class ReshapeInfer : public abstract::OpInferBase {
|
|||
MS_EXCEPTION_IF_CHECK_FAIL(LongToSize(shape_value[0]) == shape_vector.size(),
|
||||
"Illegal shape of shape value");
|
||||
output_shape = shape_vector;
|
||||
if (std::count_if(output_shape.begin(), output_shape.end(), [](int64_t s) { return s < 0; }) > 1) {
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,6 @@ class NetCdist(nn.Cell):
|
|||
return self.cdist(x1, x2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Error GetValue for value")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
|
|
@ -32,10 +32,11 @@ class NetReshape(nn.Cell):
|
|||
return self.reshape(x, self.target)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="At most one component of input shape can be -1, but got [-1, -1]")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_reshape_shape():
|
||||
"""
|
||||
|
@ -49,10 +50,11 @@ def test_dynamic_reshape_shape():
|
|||
test_dynamic.test_dynamic_grad_net(x)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="At most one component of input shape can be -1, but got [-1, -1]")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_reshape_rank():
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue