forked from mindspore-Ecosystem/mindspore
fix slice op bug
This commit is contained in:
parent
f0016f5574
commit
fb64e14265
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -102,8 +102,10 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
ret = LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeBool) {
|
||||
ret = LaunchKernel<bool>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
ret = LaunchKernel<double>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Slice op only support input_x int32 and float32";
|
||||
MS_LOG(ERROR) << "Slice op only support input_x bool,int32,float32 and float64";
|
||||
return false;
|
||||
}
|
||||
return ret;
|
||||
|
|
|
@ -55,9 +55,14 @@ class SliceCPUKernel : public CPUKernel {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
|
|
|
@ -86,8 +86,10 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
ret = LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeBool) {
|
||||
ret = LaunchKernel<bool>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
ret = LaunchKernel<double>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Slice op only support input_x int32 and float32";
|
||||
MS_LOG(ERROR) << "Slice op only support input_x bool,int32,float32 and float64";
|
||||
return false;
|
||||
}
|
||||
return ret;
|
||||
|
|
|
@ -60,10 +60,23 @@ MS_REG_CPU_KERNEL(
|
|||
SliceGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SliceGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
SliceGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SliceGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SliceGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SliceGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SliceGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SliceGradCPUKernel);
|
||||
} // namespace kernel
|
||||
|
|
|
@ -293,7 +293,7 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
|
|||
TypeError: If `quant_delay` is not greater than or equal to 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> fake_quant = nn.FakeQuantWithMinMaxObserver()
|
||||
|
@ -448,7 +448,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
|
|||
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> qconfig = compression.quant.create_quant_config()
|
||||
|
|
|
@ -4572,7 +4572,7 @@ class BroadcastTo(PrimitiveWithInfer):
|
|||
target shape is in an invalid location.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> shape = (2, 3)
|
||||
|
|
|
@ -78,6 +78,21 @@ def test_slice_grad2():
|
|||
[[0., 0.], [8., 9.], [10., 11.]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
def test_slice_grad3():
|
||||
x = Tensor(np.array([[[1.0, 3.5, 5.8], [2.5, 4, 1]], [[3.5, 15.3, 3.1], [2.2, 4.0, 1.1]],
|
||||
[[43.4, 1.1, 12.1], [2.4, 6.5, 6.3]]]), mstype.float64)
|
||||
dy = Tensor(np.array([[[3.1, 1.1, 2.2]], [[4.4, 1.2, 4.2]]]), mstype.float64)
|
||||
slicegrad = SliceGrad()
|
||||
output = slicegrad(dy, x)
|
||||
expect = [[[0., 0., 0.],
|
||||
[3.1, 1.1, 2.2]],
|
||||
[[0., 0., 0.],
|
||||
[4.4, 1.2, 4.2]],
|
||||
[[0., 0., 0.],
|
||||
[0., 0., 0.]]]
|
||||
print("output:\n", output)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
class StridedSliceGrad(nn.Cell):
|
||||
def __init__(self, x, begin, end, stride):
|
||||
super(StridedSliceGrad, self).__init__()
|
||||
|
|
|
@ -69,6 +69,14 @@ def test_slice2():
|
|||
output = slice_op(x)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
def test_slice_float64():
|
||||
data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
|
||||
[[3, 3, 3], [4, 4, 4]],
|
||||
[[5, 5, 5], [6, 6, 6]]]).astype(np.float64))
|
||||
slice_op = P.Slice()
|
||||
output = slice_op(data, (1, 0, 0), (1, 1, 3))
|
||||
expect = [[[3.0, 3.0, 3.0]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
class Slice3(nn.Cell):
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue