!34184 Fix bugs in the StridedSlice and DynamicStitch ops

Merge pull request !34184 from hanhuifeng/asr_bug
commit be52fe556e
i-robot, 2022-05-11 03:13:12 +00:00, committed by Gitee
4 changed files with 56 additions and 7 deletions


@@ -40,7 +40,11 @@ bool DynamicStitchKernelMod::Init(const CNodePtr &kernel_node) {
size_t index_type_size = sizeof(int);
data_type_size_ = GetDtypeNbyte(TypeIdToString(data_type, true));
auto first_data_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, n_);
- one_data_ele_num_ = first_data_shape[first_data_shape.size() - 1];
+ auto first_index_dims = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0).size();
+ one_data_ele_num_ = 1;
+ for (size_t d = first_index_dims; d < first_data_shape.size(); ++d) {
+   one_data_ele_num_ *= first_data_shape[d];
+ }
for (size_t i = 0; i < n_; i++) {
auto data_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, n_ + i);
size_t data_size = std::accumulate(data_shape.begin(), data_shape.end(), 1, std::multiplies<size_t>());
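The hunk above replaces the per-index element count: the old code used only the last dimension of the first data input, while the fix multiplies every data dimension beyond the rank of the index tensor, so multidimensional slices are copied in full. A small NumPy sketch of the same arithmetic, with hypothetical shapes chosen for illustration (not taken from the commit):

import numpy as np

# Hypothetical shapes: indices of shape (2,), first data input of shape (2, 2, 3).
first_index_shape = (2,)
first_data_shape = (2, 2, 3)

# Old computation: only the last data dimension -> 3 elements per index.
old_one_data_ele_num = first_data_shape[-1]

# Fixed computation: product of every data dimension beyond the index rank
# -> 2 * 3 = 6 elements copied for each stitched index.
first_index_dims = len(first_index_shape)
new_one_data_ele_num = int(np.prod(first_data_shape[first_index_dims:]))

print(old_one_data_ele_num, new_one_data_ele_num)  # 3 6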


@@ -375,7 +375,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
}
ret_in_shape = DynamicComputeInferShape(primitive, begin_v, end_v, strides_v, x_shape, begin_len);
- if (max_shape.empty() || min_shape.empty()) {
+ if (x_is_dyn && (max_shape.empty() || min_shape.empty())) {
return std::make_shared<abstract::Shape>(ret_in_shape);
}
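The changed condition makes the early return of a plain dynamic shape (one carrying no min/max bounds) apply only when the input x itself is dynamic; a static input now continues to the branch that keeps min_shape and max_shape. A rough Python restatement of the fixed branch, using a hypothetical helper name for illustration:

def pick_out_shape(x_is_dyn, ret_in_shape, min_shape, max_shape):
    # Mirrors the fixed C++ condition: only a dynamic input without
    # usable bounds falls back to a shape that has no min/max info.
    if x_is_dyn and (not max_shape or not min_shape):
        return {"shape": ret_in_shape}
    return {"shape": ret_in_shape, "min_shape": min_shape, "max_shape": max_shape}

# A static input (x_is_dyn=False) keeps whatever bounds were computed.
print(pick_out_shape(False, [2, -1], [2, 1], [2, 8]))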


@@ -3632,11 +3632,9 @@ class StridedSlice(PrimitiveWithInfer):
rets = {'shape': ret_shape,
'dtype': x['dtype'],
'value': None}
- if max_shape is not None and min_shape is not None:
-   rets = self._compute_max_min_shape(rets, x_shape, max_shape, min_shape, ret_shape)
-   return rets
+ if -1 in x_shape and (max_shape is None or min_shape is None):
+   return rets
+ return self._compute_max_min_shape(rets, x_shape, max_shape, min_shape, ret_shape)
ret_shape = self._compute_slicing_shape(x_shape, begin_v['value'], end_v['value'], strides_v['value'])
if all(ret_shape):
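With the reordered branch above, a static x_shape always goes through _compute_max_min_shape, and the bare rets dict is returned only when x_shape is dynamic (contains -1) and bounds are unavailable. For reference, a minimal static StridedSlice call following the documented MindSpore API (an illustrative example, not code from this commit):

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.arange(24).reshape(2, 3, 4), mindspore.float32)
strided_slice = ops.StridedSlice()
# Equivalent to x[0:2, 0:3:2, 0:4:2] in NumPy indexing.
out = strided_slice(x, (0, 0, 0), (2, 3, 4), (1, 2, 2))
print(out.shape)  # (2, 2, 2)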


@@ -55,3 +55,50 @@ def test_net_int32():
net = Net()
output = net(indices, data)
assert np.array_equal(output.asnumpy(), expected)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_1():
"""
Feature: Test DynamicStitch op.
Description: Each index corresponds to a single number.
Expectation: the result matches the expected array.
"""
x1 = Tensor(np.array([0, 1]), mindspore.int32)
x2 = Tensor([1], mindspore.int32)
y1 = Tensor(np.array([1, 3]), mindspore.int32)
y2 = Tensor(np.array([2]), mindspore.int32)
expected = np.array([1, 2]).astype(np.int32)
indices = [x1, x2]
data = [y1, y2]
net = Net()
output = net(indices, data)
assert np.array_equal(output.asnumpy(), expected)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_2():
"""
Feature: Test DynamicStitch op.
Description: Each index corresponds to a multidimensional array.
Expectation: the result matches the expected array.
"""
x1 = Tensor(np.array([0, 2]), mindspore.int32)
x2 = Tensor([1], mindspore.int32)
y1 = Tensor(np.array([[[1, 2, 3], [4, 5, 6]],
[[13, 14, 15], [16, 17, 18]]]), mindspore.int32)
y2 = Tensor(np.array([[[7, 8, 9], [10, 11, 12]]]), mindspore.int32)
expected = np.array([[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]],
[[13, 14, 15], [16, 17, 18]]]).astype(np.int32)
indices = [x1, x2]
data = [y1, y2]
net = Net()
output = net(indices, data)
assert np.array_equal(output.asnumpy(), expected)
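The two new tests exercise exactly the case fixed in the kernel hunk: test_net_1 stitches one number per index, while test_net_2 stitches a 2x3 block per index. A NumPy reference of the DynamicStitch semantics, a sketch for checking the expected arrays (dynamic_stitch_ref is an illustrative name, not part of the test file or the GPU kernel):

import numpy as np

def dynamic_stitch_ref(indices, data):
    """output[index] = corresponding data slice; later inputs overwrite earlier ones."""
    indices = [np.asarray(i) for i in indices]
    data = [np.asarray(d) for d in data]
    n = int(max(i.max() for i in indices)) + 1
    slice_shape = data[0].shape[indices[0].ndim:]
    out = np.zeros((n,) + slice_shape, dtype=data[0].dtype)
    for idx, val in zip(indices, data):
        out[idx.reshape(-1)] = val.reshape((-1,) + slice_shape)
    return out

# Reproduces the expectation of test_net_2.
x1 = np.array([0, 2], np.int32)
x2 = np.array([1], np.int32)
y1 = np.array([[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]], np.int32)
y2 = np.array([[[7, 8, 9], [10, 11, 12]]], np.int32)
print(dynamic_stitch_ref([x1, x2], [y1, y2]))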