Add complex64/complex128 dtypes for Padding and extend the Softsign tests to float64.

liqiliang 2022-07-26 14:11:44 +08:00
parent 1b8a5ae512
commit 9a2e0c9cfe
7 changed files with 43 additions and 5 deletions

View File

@@ -19,10 +19,16 @@
#include <algorithm>
#include <utility>
#include <memory>
#include <complex>
#include "mindspore/core/ops/padding.h"
namespace mindspore {
namespace kernel {
namespace {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
} // namespace
bool PaddingCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
@@ -114,6 +120,10 @@ const std::vector<std::pair<KernelAttr, PaddingCpuKernelMod::KernelRunFunc>> &Pa
&PaddingCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&PaddingCpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&PaddingCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&PaddingCpuKernelMod::LaunchKernel<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&PaddingCpuKernelMod::LaunchKernel<bool>},
};
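The four complex entries above extend the CPU func_list alongside the existing registrations, so a complex64 input is dispatched to PaddingCpuKernelMod::LaunchKernel<complex64>. A minimal usage sketch of what this enables from Python (assuming a MindSpore build that contains this commit; ops.Padding pads the last axis, which must have size 1, up to pad_dim_size with zeros):

import numpy as np
from mindspore import Tensor, context, ops

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# complex64 input now has a registered CPU kernel entry.
x = Tensor(np.array([[1 + 2j], [3 - 4j]], dtype=np.complex64))
y = ops.Padding(pad_dim_size=4)(x)

expected = np.zeros((2, 4), dtype=np.complex64)
expected[:, 0] = [1 + 2j, 3 - 4j]
np.testing.assert_allclose(y.asnumpy(), expected)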

View File

@@ -21,9 +21,12 @@
#include <memory>
#include "mindspore/core/ops/padding.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/padding_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
bool PaddingGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
@@ -111,6 +114,10 @@ std::vector<std::pair<KernelAttr, PaddingGpuKernelMod::PaddingFunc>> PaddingGpuK
&PaddingGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&PaddingGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&PaddingGpuKernelMod::LaunchKernel<Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&PaddingGpuKernelMod::LaunchKernel<Complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&PaddingGpuKernelMod::LaunchKernel<bool>}};
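The GPU registrations mirror the CPU ones but use the device-side Complex<T> alias from cuda_ops/complex.h rather than std::complex. From Python the call is identical apart from the device target; a brief sketch (assumes a GPU build with this commit):

import numpy as np
from mindspore import Tensor, context, ops
from mindspore import dtype as mstype

context.set_context(device_target="GPU")

# complex128 is padded on the GPU kernel added above.
x = Tensor(np.ones((3, 4, 1), dtype=np.complex128))
y = ops.Padding(pad_dim_size=10)(x)
assert y.shape == (3, 4, 10)
assert y.dtype == mstype.complex128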

View File

@@ -17,6 +17,10 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/padding_impl.cuh"
#include "include/cuda_runtime.h"
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
__global__ void CalculatePaddingKernel(const T *input_ptr, size_t output_outer_size_, size_t pad_dim_size,
@@ -66,6 +70,14 @@ template CUDA_LIB_EXPORT void CalculatePadding<float>(const float *input_ptr, si
template CUDA_LIB_EXPORT void CalculatePadding<double>(const double *input_ptr, size_t output_outer_size_,
size_t pad_dim_size, double *output_ptr,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalculatePadding<Complex<float>>(const Complex<float> *input_ptr,
size_t output_outer_size_, size_t pad_dim_size,
Complex<float> *output_ptr, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalculatePadding<Complex<double>>(const Complex<double> *input_ptr,
size_t output_outer_size_, size_t pad_dim_size,
Complex<double> *output_ptr, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalculatePadding<bool>(const bool *input_ptr, size_t output_outer_size_,
size_t pad_dim_size, bool *output_ptr, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@@ -23,6 +23,7 @@
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace ops {
@@ -30,8 +31,16 @@ namespace {
TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto name = primitive->name();
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBool};
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
bool is_cpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice);
std::set<TypePtr> valid_types{};
if (is_gpu || is_cpu) {
valid_types = common_valid_types_with_complex_and_bool;
} else {
valid_types = common_valid_types_with_bool;
}
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, name);
}
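PaddingInferType now picks its dtype whitelist from the device target: CPU and GPU use common_valid_types_with_complex_and_bool, every other backend keeps common_valid_types_with_bool, so complex inputs are still rejected at infer time on Ascend. A conceptual mirror of that gate in Python (the helper name is illustrative only, not a MindSpore API):

from mindspore import context

def padding_accepts_complex():
    # Mirrors the gate added above: complex dtypes pass Padding's type
    # check only when the configured backend is CPU or GPU.
    return context.get_context("device_target") in ("CPU", "GPU")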

View File

@@ -398,7 +398,7 @@ def test_invert_dynamic_shape(mode):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(2,), (4, 5), (3, 4, 5, 6)])
@pytest.mark.parametrize('dtype, tol', [(np.float16, 1.0e-3), (np.float32, 1.0e-5)])
@pytest.mark.parametrize('dtype, tol', [(np.float16, 1.0e-3), (np.float32, 1.0e-4), (np.float64, 1.0e-5)])
def test_softsign(shape, dtype, tol):
"""
Feature: ALL To ALL
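The Softsign test gains a float64 case with per-dtype tolerances (float16: 1e-3, float32: 1e-4, float64: 1e-5). A sketch of the kind of check those parameters drive, comparing ops.Softsign against the reference softsign(x) = x / (1 + |x|) (the helper and assertion style are assumed, not copied from the test body):

import numpy as np
from mindspore import Tensor, ops

def check_softsign(shape, dtype, tol):
    x = np.random.uniform(-5, 5, size=shape).astype(dtype)
    out = ops.Softsign()(Tensor(x)).asnumpy()
    # Reference implementation of softsign for comparison.
    np.testing.assert_allclose(out, x / (1 + np.abs(x)), rtol=tol)

check_softsign((4, 5), np.float64, 1.0e-5)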

View File

@@ -48,7 +48,7 @@ class PaddingDynamicShapeNet(nn.Cell):
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('shape', [(2, 1), (2, 4, 1), (3, 4, 5, 1)])
@pytest.mark.parametrize('dtype', [np.int32, np.float16, np.float32])
@pytest.mark.parametrize('dtype', [np.bool_, np.uint32, np.float16, np.float32, np.complex64, np.complex128])
@pytest.mark.parametrize('pad_dim_size', [2, 4, 10])
def test_padding(mode, shape, dtype, pad_dim_size):
"""

View File

@@ -47,7 +47,7 @@ class PaddingDynamicShapeNet(nn.Cell):
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('shape', [(2, 1), (2, 4, 1), (3, 4, 5, 1)])
@pytest.mark.parametrize('dtype', [np.uint32, np.float16, np.float32])
@pytest.mark.parametrize('dtype', [np.bool_, np.uint32, np.float16, np.float32, np.complex64, np.complex128])
@pytest.mark.parametrize('pad_dim_size', [2, 4, 10])
def test_padding(mode, shape, dtype, pad_dim_size):
"""