forked from mindspore-Ecosystem/mindspore
!34184 fix the bug of StridedSlice op and DynamicStitch op
Merge pull request !34184 from hanhuifeng/asr_bug
This commit is contained in:
commit
be52fe556e
|
@ -40,7 +40,11 @@ bool DynamicStitchKernelMod::Init(const CNodePtr &kernel_node) {
|
|||
size_t index_type_size = sizeof(int);
|
||||
data_type_size_ = GetDtypeNbyte(TypeIdToString(data_type, true));
|
||||
auto first_data_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, n_);
|
||||
one_data_ele_num_ = first_data_shape[first_data_shape.size() - 1];
|
||||
auto first_index_dims = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0).size();
|
||||
one_data_ele_num_ = 1;
|
||||
for (size_t d = first_index_dims; d < first_data_shape.size(); ++d) {
|
||||
one_data_ele_num_ *= first_data_shape[d];
|
||||
}
|
||||
for (size_t i = 0; i < n_; i++) {
|
||||
auto data_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, n_ + i);
|
||||
size_t data_size = std::accumulate(data_shape.begin(), data_shape.end(), 1, std::multiplies<size_t>());
|
||||
|
|
|
@ -375,7 +375,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
ret_in_shape = DynamicComputeInferShape(primitive, begin_v, end_v, strides_v, x_shape, begin_len);
|
||||
|
||||
if (max_shape.empty() || min_shape.empty()) {
|
||||
if (x_is_dyn && (max_shape.empty() || min_shape.empty())) {
|
||||
return std::make_shared<abstract::Shape>(ret_in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -3632,11 +3632,9 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
rets = {'shape': ret_shape,
|
||||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
|
||||
if max_shape is not None and min_shape is not None:
|
||||
rets = self._compute_max_min_shape(rets, x_shape, max_shape, min_shape, ret_shape)
|
||||
|
||||
return rets
|
||||
if -1 in x_shape and (max_shape is None or min_shape is None):
|
||||
return rets
|
||||
return self._compute_max_min_shape(rets, x_shape, max_shape, min_shape, ret_shape)
|
||||
|
||||
ret_shape = self._compute_slicing_shape(x_shape, begin_v['value'], end_v['value'], strides_v['value'])
|
||||
if all(ret_shape):
|
||||
|
|
|
@ -55,3 +55,50 @@ def test_net_int32():
|
|||
net = Net()
|
||||
output = net(indices, data)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_1():
|
||||
"""
|
||||
Feature: Test dynamicstitch op.
|
||||
Description: An index corresponds to a number
|
||||
Expectation: the result match expected array.
|
||||
"""
|
||||
x1 = Tensor(np.array([0, 1]), mindspore.int32)
|
||||
x2 = Tensor([1], mindspore.int32)
|
||||
y1 = Tensor(np.array([1, 3]), mindspore.int32)
|
||||
y2 = Tensor(np.array([2]), mindspore.int32)
|
||||
expected = np.array([1, 2]).astype(np.int32)
|
||||
|
||||
indices = [x1, x2]
|
||||
data = [y1, y2]
|
||||
net = Net()
|
||||
output = net(indices, data)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_2():
|
||||
"""
|
||||
Feature: Test dynamicstitch op.
|
||||
Description: An index corresponds to a multidimensional array.
|
||||
Expectation: the result match expected array.
|
||||
"""
|
||||
x1 = Tensor(np.array([0, 2]), mindspore.int32)
|
||||
x2 = Tensor([1], mindspore.int32)
|
||||
y1 = Tensor(np.array([[[1, 2, 3], [4, 5, 6]],
|
||||
[[13, 14, 15], [16, 17, 18]]]), mindspore.int32)
|
||||
y2 = Tensor(np.array([[[7, 8, 9], [10, 11, 12]]]), mindspore.int32)
|
||||
expected = np.array([[[1, 2, 3], [4, 5, 6]],
|
||||
[[7, 8, 9], [10, 11, 12]],
|
||||
[[13, 14, 15], [16, 17, 18]]]).astype(np.int32)
|
||||
|
||||
indices = [x1, x2]
|
||||
data = [y1, y2]
|
||||
net = Net()
|
||||
output = net(indices, data)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
||||
|
|
Loading…
Reference in New Issue