!39856 Support dynamic shape for the existing Range operator, and supplement its interfaces and data types.
Merge pull request !39856 from NaCN/range_data
Commit: fba8e11bcd
@@ -8,9 +8,9 @@ mindspore.ops.range

     The data types of the three inputs must be the same. The data type of the returned Tensor is consistent with that of the inputs.

     Parameters:
-        - **start** (Tensor) - A scalar Tensor, the first number in the sequence. The data type must be int32 or float32.
-        - **limit** (Tensor) - A scalar Tensor, the upper limit of the sequence, exclusive. The data type must be int32 or float32.
-        - **delta** (Tensor) - A scalar Tensor, the step size between values in the sequence. The data type must be int32 or float32.
+        - **start** (Tensor) - A scalar Tensor, the first number in the sequence. The data type must be int32, int64, float32 or float64.
+        - **limit** (Tensor) - A scalar Tensor, the upper limit of the sequence, exclusive. The data type must be int32, int64, float32 or float64.
+        - **delta** (Tensor) - A scalar Tensor, the step size between values in the sequence. The data type must be int32, int64, float32 or float64.

     Returns:
         A 1-D Tensor whose data type is the same as that of the inputs.
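A minimal usage sketch of the documented function, with values borrowed from the new CPU tests in this PR (illustration only, not part of the diff):

```python
import mindspore as ms
from mindspore import ops, Tensor

# int64 inputs are now accepted; the output dtype follows the inputs.
out = ops.range(Tensor(2, ms.int64), Tensor(5, ms.int64), Tensor(2, ms.int64))
print(out)  # expected: [2 4], dtype Int64
```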
@@ -22,6 +22,17 @@ namespace kernel {
 namespace {
 constexpr size_t kRangeInputsNum = 3;
 constexpr size_t kRangeOutputsNum = 1;
+
+template <typename T>
+T Sign(T num) {
+  if (num > static_cast<T>(0.0)) {
+    return static_cast<T>(1.0);
+  } else if (num == static_cast<T>(0.0)) {
+    return static_cast<T>(0.0);
+  } else {
+    return static_cast<T>(-1.0);
+  }
+}
 }  // namespace

 void RangeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
@@ -40,6 +51,10 @@ bool RangeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
     LaunchKernel<int32_t>(inputs, outputs);
   } else if (dtype_ == kNumberTypeFloat32) {
     LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeInt64) {
+    LaunchKernel<int64_t>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat64) {
+    LaunchKernel<double>(inputs, outputs);
   } else {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input must be int or float, but got "
                       << TypeIdLabel(dtype_);
@@ -66,17 +81,16 @@ void RangeCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, cons
   }

   auto output = reinterpret_cast<T *>(outputs[0]->addr);
-  size_t max_index = outputs[0]->size / sizeof(T) - 1;
-  size_t index = 0;
-  while ((delta > 0 && start < limit) || (delta < 0 && start > limit)) {
-    if (index > max_index) {
+  size_t max_size = outputs[0]->size / sizeof(T);
+  if (Sign(delta) * Sign(limit - start) > 0) {
+    output_size_ = static_cast<size_t>(std::ceil(static_cast<double>(limit - start) / static_cast<double>(delta)));
+    if (output_size_ > max_size) {
       MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the output element number exceeds the maximum number.";
     }
-    output[index] = start;
-    start += delta;
-    index++;
+    for (size_t index = 0; index < output_size_; index++, start += delta) {
+      output[index] = start;
+    }
   }
-  output_size_ = index;
 }

 MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Range, RangeCpuKernelMod);
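The rewritten LaunchKernel computes the output length up front instead of looping until `start` crosses `limit`: the sequence is non-empty only when `delta` steps from `start` toward `limit`, i.e. `Sign(delta) * Sign(limit - start) > 0`, and its length is then `ceil((limit - start) / delta)`. A minimal Python sketch of that length rule (an illustration, not MindSpore code):

```python
import math

def sign(x):
    # Mirrors the kernel's Sign helper: returns 1, 0, or -1.
    return (x > 0) - (x < 0)

def range_length(start, limit, delta):
    # Non-empty only when delta points from start toward limit.
    if sign(delta) * sign(limit - start) <= 0:
        return 0
    return math.ceil((limit - start) / delta)

assert range_length(2, 5, 2) == 2    # [2, 4]
assert range_length(8, 1, -1) == 7   # [8, 7, ..., 2]
assert range_length(5, 2, 1) == 0    # wrong direction -> empty output
```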
@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
                       RangeGpuKernelMod, float)
 MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       RangeGpuKernelMod, int)
+MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      RangeGpuKernelMod, int64_t)
+MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      RangeGpuKernelMod, double)
 }  // namespace kernel
 }  // namespace mindspore
|
@ -34,5 +34,11 @@ void CalRange(const int size, const float start, const float limit, const float
|
|||
template CUDA_LIB_EXPORT void CalRange<float>(const int size, const float start, const float limit, const float delta,
|
||||
const float *input, float *output, cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalRange<double>(const int size, const float start, const float limit, const float delta,
|
||||
const double *input, double *output, cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalRange<int>(const int size, const float start, const float limit, const float delta,
|
||||
const int *input, int *output, cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalRange<int64_t>(const int size, const float start, const float limit, const float delta,
|
||||
const int64_t *input, int64_t *output, cudaStream_t cuda_stream);
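For reference, each CalRange instantiation fills the output elementwise; the value rule a range kernel of this shape applies is output[i] = start + i * delta (stated here as an assumption about the CUDA implementation, which this diff does not show). In Python terms:

```python
def cal_range(size, start, delta):
    # Assumed elementwise rule of the GPU kernel: output[i] = start + i * delta.
    return [start + i * delta for i in range(size)]

assert cal_range(2, 2, 2) == [2, 4]       # matches the int64 CPU test
assert cal_range(3, 2, 1) == [2, 3, 4]    # matches the float64 CPU test values
```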
|
@ -22,7 +22,9 @@ range_op_info = CpuRegOp("Range") \
|
|||
.input(2, "delta") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -489,11 +489,11 @@ def range(start, limit, delta):
|
|||
|
||||
Args:
|
||||
start (Tensor): A scalar Tensor. The first number in the sequence. Must have
|
||||
type: int32 or float32.
|
||||
type: int32 ,int64, float32 or float64.
|
||||
limit (Tensor): A scalar Tensor. Upper limit of the sequence, exclusive. Must
|
||||
have type: int32 or float32.
|
||||
have type: int32 ,int64, float32 or float64.
|
||||
delta (Tensor): A scalar Tensor. Number that increments `start`. Must have
|
||||
type: int32 or float32.
|
||||
type: int32 ,int64, float32 or float64.
|
||||
|
||||
Returns:
|
||||
A 1-D Tensor, with the same type as the inputs.
@@ -6209,7 +6209,7 @@ class Range(PrimitiveWithCheck):
         validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)

     def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
-        valid_dtypes = [mstype.int32, mstype.float32]
+        valid_dtypes = [mstype.int32, mstype.float32, mstype.int64, mstype.float64]
         inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
         validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
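With the widened `valid_dtypes` list, int64 and float64 inputs now pass the check, while other dtypes are still rejected and all three inputs must share one dtype. A hedged illustration (float16 is chosen arbitrarily here as an unsupported type; the validator is expected to raise TypeError):

```python
import mindspore as ms
from mindspore import ops, Tensor

range_op = ops.Range()
# Accepted after this change:
out = range_op(Tensor(2, ms.float64), Tensor(5, ms.float64), Tensor(1, ms.float64))
# Still rejected (should raise TypeError): an unsupported or mixed dtype, e.g.
# range_op(Tensor(2, ms.float16), Tensor(5, ms.float16), Tensor(1, ms.float16))
```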
@@ -88,6 +88,36 @@ def test_range_op_float():
     assert np.array_equal(result.asnumpy(), expect)


+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_range_op_int64():
+    """
+    Feature: test Range op on CPU.
+    Description: test the Range when input is int64.
+    Expectation: result is right.
+    """
+    range_op = ms.ops.Range()
+    result = range_op(ms.Tensor(2, ms.int64), ms.Tensor(5, ms.int64), ms.Tensor(2, ms.int64))
+    expect = np.array([2, 4], np.int64)
+    assert np.array_equal(result.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_range_op_float64():
+    """
+    Feature: test Range op on CPU.
+    Description: test the Range when input is float64.
+    Expectation: result is right.
+    """
+    range_op = ms.ops.Range()
+    result = range_op(ms.Tensor(2, ms.float64), ms.Tensor(5, ms.float64), ms.Tensor(1, ms.float64))
+    expect = np.array([2, 3, 4], np.float64)
+    assert np.array_equal(result.asnumpy(), expect)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
@@ -65,6 +65,7 @@ def test_range_precision_end_equals_last_element():
     np_expected = np.arange(-12000, -12053, -1, dtype=np.float32)
     np.testing.assert_allclose(ms_out, np_expected, rtol=1e-5)

+
 @pytest.mark.level1
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -89,6 +90,7 @@ def test_range_int():
     np_expected = np.array([3, -2, -7])
     np.testing.assert_array_equal(ms_out, np_expected)

+
 @pytest.mark.level1
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -113,6 +115,67 @@ def test_range_float():
     np_expected = np.array([1.5])
     np.testing.assert_array_almost_equal(ms_out, np_expected)

+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_range_int64():
+    """
+    Feature: test Range op on GPU.
+    Description: test the Range when input is int64.
+    Expectation: result is right.
+    """
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(2, mstype.int64), Tensor(5, mstype.int64), Tensor(1, mstype.int64)).asnumpy()
+    np_expected = np.array([2, 3, 4])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(-24, mstype.int64), Tensor(1, mstype.int64), Tensor(4, mstype.int64)).asnumpy()
+    np_expected = np.array([-24, -20, -16, -12, -8, -4, 0])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(8, mstype.int64), Tensor(1, mstype.int64), Tensor(-1, mstype.int64)).asnumpy()
+    np_expected = np.array([8, 7, 6, 5, 4, 3, 2])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(3, mstype.int64), Tensor(-11, mstype.int64), Tensor(-5, mstype.int64)).asnumpy()
+    np_expected = np.array([3, -2, -7])
+    np.testing.assert_array_equal(ms_out, np_expected)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_range_float64():
+    """
+    Feature: test Range op on GPU.
+    Description: test the Range when input is float64.
+    Expectation: result is right.
+    """
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(2.3, mstype.float64), Tensor(5.5, mstype.float64), Tensor(1.2, mstype.float64)).asnumpy()
+    np_expected = np.array([2.3, 3.5, 4.7])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(-4, mstype.float64), Tensor(-1, mstype.float64), Tensor(1.5, mstype.float64)).asnumpy()
+    np_expected = np.array([-4.0, -2.5])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(8.0, mstype.float64), Tensor(1.0, mstype.float64), Tensor(-1.0, mstype.float64)).asnumpy()
+    np_expected = np.array([8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+    range_net = RangeNet()
+    ms_out = range_net(Tensor(1.5, mstype.float64), Tensor(-1, mstype.float64), Tensor(-18.9, mstype.float64)).asnumpy()
+    np_expected = np.array([1.5])
+    np.testing.assert_array_almost_equal(ms_out, np_expected)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -123,6 +186,7 @@ def test_range_invalid_max_output_length():
         _ = P.Range(None)
         _ = P.Range('5')

+
 @pytest.mark.level1
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard