Add support for all data types to the StridedSlice GPU kernel

This commit is contained in:
TFBunny 2021-02-04 16:46:27 -05:00
parent 537b6a31da
commit b45a9f56c8
9 changed files with 188 additions and 37 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,10 +22,20 @@ MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32
StridedSliceGpuKernel, float)
// Per-element-type StridedSlice registrations. Each entry names the same dtype for its input and output
// attribute and passes the matching C++ type as the StridedSliceGpuKernel template argument.
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
StridedSliceGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
StridedSliceGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
StridedSliceGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
StridedSliceGpuKernel, short) // NOLINT
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
StridedSliceGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
StridedSliceGpuKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
StridedSliceGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
StridedSliceGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
StridedSliceGpuKernel, uchar)
MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
#include <vector>
#include <bitset>
@ -210,4 +210,4 @@ class StridedSliceGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,10 +22,20 @@ MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFlo
StridedSliceGradGpuKernel, float)
// Per-element-type StridedSliceGrad registrations. Each entry names the same dtype for its input and output
// attribute and passes the matching C++ type as the StridedSliceGradGpuKernel template argument.
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
StridedSliceGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
StridedSliceGradGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
StridedSliceGradGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
StridedSliceGradGpuKernel, short) // NOLINT
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
StridedSliceGradGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
StridedSliceGradGpuKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
StridedSliceGradGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
StridedSliceGradGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
StridedSliceGradGpuKernel, uchar)
MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
#include <vector>
#include <bitset>
@ -211,4 +211,4 @@ class StridedSliceGradGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H

View File

@ -159,7 +159,6 @@ void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int
dy, dx);
}
template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream);
@ -167,7 +166,6 @@ template void CalSliceGrad<float>(const size_t input_size, const float *dy, cons
const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
@ -175,7 +173,6 @@ template void CalSliceGrad<half>(const size_t input_size, const half *dy, const
const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<int>(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
@ -183,8 +180,6 @@ template void CalSliceGrad<int>(const size_t input_size, const int *dy, const st
const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<short>(const size_t input_size, short *addr, const float value, // NOLINT
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const short *input, short *output, // NOLINT
@ -195,8 +190,6 @@ template void CalSliceGrad<short>(const size_t input_size, const short *dy, //
short *output, // NOLINT
cudaStream_t cuda_stream);
template void FillDeviceArray<unsigned char>(const size_t input_size, unsigned char *addr, const float value,
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
@ -206,8 +199,6 @@ template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigne
const std::vector<int64_t> size, unsigned char *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<int64_t>(const size_t input_size, int64_t *addr, const float value,
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
@ -216,7 +207,6 @@ template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy,
const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
@ -224,34 +214,71 @@ template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const
const std::vector<int64_t> begin, const std::vector<int64_t> size, bool *output,
cudaStream_t cuda_stream);
// Explicit instantiations of FillDeviceArray, one per element type T of the destination array.
// The fill value is passed as `const float` for every T, including the 64-bit integer types.
// NOTE(review): a float cannot represent every int64_t/uint64_t value exactly — confirm callers only
// fill with small values (e.g. 0) for those types.
template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
template void FillDeviceArray<int64_t>(const size_t input_size, int64_t *addr, const float value,
cudaStream_t cuda_stream);
template void FillDeviceArray<int>(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
template void FillDeviceArray<short>(const size_t input_size, short *addr, const float value, // NOLINT
cudaStream_t cuda_stream);
template void FillDeviceArray<int8_t>(const size_t input_size, int8_t *addr, const float value,
cudaStream_t cuda_stream);
template void FillDeviceArray<uint64_t>(const size_t input_size, uint64_t *addr, const float value,
cudaStream_t cuda_stream);
template void FillDeviceArray<uint32_t>(const size_t input_size, uint32_t *addr, const float value,
cudaStream_t cuda_stream);
template void FillDeviceArray<uint16_t>(const size_t input_size, uint16_t *addr, const float value,
cudaStream_t cuda_stream);
template void FillDeviceArray<unsigned char>(const size_t input_size, unsigned char *addr, const float value,
cudaStream_t cuda_stream);
template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
// Explicit instantiations of StridedSlice, one per element type. Every instantiation takes the same
// input_shape / begin / strides / output_shape arguments and differs only in the type of the
// input and output pointers.
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const bool *input, bool *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const float *input, float *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const half *input, half *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const int64_t *input, int64_t *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const int *input, int *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const short *input, short *output, cudaStream_t cuda_stream); // NOLINT
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const int8_t *input, int8_t *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const uint64_t *input, uint64_t *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const uint32_t *input, uint32_t *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const uint16_t *input, uint16_t *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const unsigned char *input, unsigned char *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const bool *input, bool *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const int64_t *input, int64_t *output, cudaStream_t cuda_stream);
// Explicit instantiations of StridedSliceGrad, one per element type. Every instantiation takes the same
// dy_shape / begin / strides / dx_shape arguments and differs only in the type of the dy and dx pointers.
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
bool *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const float *dy, float *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const half *dy,
half *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const int64_t *dy, int64_t *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const int *dy,
int *dx, cudaStream_t cuda_stream);
@ -261,10 +288,16 @@ template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::v
short *dx, cudaStream_t cuda_stream); // NOLINT
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
bool *dx, cudaStream_t cuda_stream);
const int8_t *dy, int8_t *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const int64_t *dy, int64_t *dx, cudaStream_t cuda_stream);
const uint64_t *dy, uint64_t *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const uint32_t *dy, uint32_t *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const uint16_t *dy, uint16_t *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_IMPL_CUH_
#include <cuda_runtime.h>
#include <vector>
@ -39,4 +39,4 @@ void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int
cudaStream_t cuda_stream);
// Declaration of the templated device-array fill helper; T is the element type pointed to by `addr`.
// NOTE(review): presumably writes `value` (converted to T) into `input_size` elements of `addr` on
// `cuda_stream` — confirm against the definition, which is not visible here.
template <typename T>
void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_IMPL_CUH_

View File

@ -43,9 +43,21 @@ class PrintNetTwoInputs(nn.Cell):
return x
class PrintNetIndex(nn.Cell):
    """Cell that prints one indexed element of its input and returns the input unchanged."""

    def __init__(self):
        super(PrintNetIndex, self).__init__()
        self.op = P.Print()

    def construct(self, x):
        # Select the element at [0][0][6][3] and pass it to the Print op.
        element = x[0][0][6][3]
        self.op(element)
        return x
def print_testcase(nptype):
# large shape
x = np.arange(20808).reshape(6, 3, 34, 34).astype(nptype)
# a value that can be stored as int8_t
x[0][0][6][3] = 125
# small shape
y = np.arange(9).reshape(3, 3).astype(nptype)
x = Tensor(x)
@ -54,8 +66,10 @@ def print_testcase(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net_1 = PrintNetOneInput()
net_2 = PrintNetTwoInputs()
net_3 = PrintNetIndex()
net_1(x)
net_2(x, y)
net_3(x)
@pytest.mark.level0

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -245,12 +245,54 @@ def strided_slice_grad(nptype):
def test_strided_slice_grad_float32():
strided_slice_grad(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_float16():
    """Run the shared strided_slice_grad driver with np.float16 as nptype."""
    nptype = np.float16
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_int64():
    """Run the shared strided_slice_grad driver with np.int64 as nptype."""
    nptype = np.int64
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_int32():
    """Run the shared strided_slice_grad driver with np.int32 as nptype."""
    nptype = np.int32
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_int16():
    """Run the shared strided_slice_grad driver with np.int16 as nptype."""
    nptype = np.int16
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_int8():
    """Run the shared strided_slice_grad driver with np.int8 as nptype."""
    nptype = np.int8
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_uint64():
    """Run the shared strided_slice_grad driver with np.uint64 as nptype."""
    nptype = np.uint64
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_uint32():
    """Run the shared strided_slice_grad driver with np.uint32 as nptype."""
    nptype = np.uint32
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_grad_uint16():
    """Run the shared strided_slice_grad driver with np.uint16 as nptype."""
    nptype = np.uint16
    strided_slice_grad(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -108,12 +108,54 @@ def strided_slice(nptype):
def test_strided_slice_float32():
strided_slice(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_float16():
    """Run the shared strided_slice driver with np.float16 as nptype."""
    nptype = np.float16
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int64():
    """Run the shared strided_slice driver with np.int64 as nptype."""
    nptype = np.int64
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int32():
    """Run the shared strided_slice driver with np.int32 as nptype."""
    nptype = np.int32
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int16():
    """Run the shared strided_slice driver with np.int16 as nptype."""
    nptype = np.int16
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_int8():
    """Run the shared strided_slice driver with np.int8 as nptype."""
    nptype = np.int8
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint64():
    """Run the shared strided_slice driver with np.uint64 as nptype."""
    nptype = np.uint64
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint32():
    """Run the shared strided_slice driver with np.uint32 as nptype."""
    nptype = np.uint32
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_strided_slice_uint16():
    """Run the shared strided_slice driver with np.uint16 as nptype."""
    nptype = np.uint16
    strided_slice(nptype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard