Add dtypes for Padding and Softsign.
commit 9a2e0c9cfe
parent 1b8a5ae512
@@ -19,10 +19,16 @@
 #include <algorithm>
 #include <utility>
 #include <memory>
+#include <complex>
 #include "mindspore/core/ops/padding.h"
 
 namespace mindspore {
 namespace kernel {
+namespace {
+using complex64 = std::complex<float>;
+using complex128 = std::complex<double>;
+}  // namespace
+
 bool PaddingCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
   kernel_name_ = base_operator->name();
@@ -114,6 +120,10 @@ const std::vector<std::pair<KernelAttr, PaddingCpuKernelMod::KernelRunFunc>> &Pa
    &PaddingCpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    &PaddingCpuKernelMod::LaunchKernel<double>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+   &PaddingCpuKernelMod::LaunchKernel<complex64>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+   &PaddingCpuKernelMod::LaunchKernel<complex128>},
   {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
    &PaddingCpuKernelMod::LaunchKernel<bool>},
 };
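For orientation: Padding extends a trailing dimension of size 1 to pad_dim_size, filling the new positions with zeros, and the new table entries simply route complex64/complex128 inputs to the same templated LaunchKernel. A minimal NumPy sketch of that behaviour for the newly registered dtypes (an illustrative reference, not the kernel code; padding_reference is a hypothetical helper name):

import numpy as np

def padding_reference(x, pad_dim_size):
    # Pad the last axis (expected to have size 1) out to pad_dim_size with zeros.
    out = np.zeros(x.shape[:-1] + (pad_dim_size,), dtype=x.dtype)
    out[..., :1] = x
    return out

for dt in (np.complex64, np.complex128, np.bool_):
    y = padding_reference(np.ones((3, 4, 1), dtype=dt), 4)
    print(y.dtype, y.shape)   # dtype is preserved, last axis becomes 4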
@@ -21,9 +21,12 @@
 #include <memory>
 #include "mindspore/core/ops/padding.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/padding_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
 
 namespace mindspore {
 namespace kernel {
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
 bool PaddingGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
   kernel_name_ = base_operator->name();
@@ -111,6 +114,10 @@ std::vector<std::pair<KernelAttr, PaddingGpuKernelMod::PaddingFunc>> PaddingGpuK
    &PaddingGpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    &PaddingGpuKernelMod::LaunchKernel<double>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+   &PaddingGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+   &PaddingGpuKernelMod::LaunchKernel<Complex<double>>},
   {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
    &PaddingGpuKernelMod::LaunchKernel<bool>}};
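With the matching GPU entries in place, a complex input should dispatch to LaunchKernel<Complex<float>> or LaunchKernel<Complex<double>> on that backend. A hedged usage sketch (it assumes a GPU-enabled MindSpore build with complex tensor support and the public ops.Padding primitive):

import numpy as np
from mindspore import Tensor, context, ops

context.set_context(device_target="GPU")        # assumes a GPU build is available
pad = ops.Padding(pad_dim_size=4)
x = Tensor(np.array([[1 + 2j], [3 - 1j]], dtype=np.complex64))
y = pad(x)                                      # expected: shape (2, 4), dtype complex64
print(y.shape, y.dtype)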
@@ -17,6 +17,10 @@
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/padding_impl.cuh"
 #include "include/cuda_runtime.h"
 #include "include/cuda_fp16.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
 
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
+
 template <typename T>
 __global__ void CalculatePaddingKernel(const T *input_ptr, size_t output_outer_size_, size_t pad_dim_size,
@@ -66,6 +70,14 @@ template CUDA_LIB_EXPORT void CalculatePadding<float>(const float *input_ptr, si
 template CUDA_LIB_EXPORT void CalculatePadding<double>(const double *input_ptr, size_t output_outer_size_,
                                                        size_t pad_dim_size, double *output_ptr,
                                                        const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalculatePadding<Complex<float>>(const Complex<float> *input_ptr,
+                                                               size_t output_outer_size_, size_t pad_dim_size,
+                                                               Complex<float> *output_ptr, const uint32_t &device_id,
+                                                               cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalculatePadding<Complex<double>>(const Complex<double> *input_ptr,
+                                                                size_t output_outer_size_, size_t pad_dim_size,
+                                                                Complex<double> *output_ptr, const uint32_t &device_id,
+                                                                cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void CalculatePadding<bool>(const bool *input_ptr, size_t output_outer_size_,
                                                      size_t pad_dim_size, bool *output_ptr, const uint32_t &device_id,
                                                      cudaStream_t cuda_stream);
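The explicit instantiations make the Complex<float> and Complex<double> variants of CalculatePadding available to the kernel-mod translation unit at link time. As a rough, assumed picture of the indexing the kernel parallelizes (roughly one thread per flat output element; this sketches the expected layout, it is not the CUDA source):

import numpy as np

def calculate_padding_flat(inp, output_outer_size, pad_dim_size, dtype):
    # Flat view of the output: position pos keeps the input value when it is the
    # first slot of its padded group, and stays zero otherwise.
    out = np.zeros(output_outer_size * pad_dim_size, dtype=dtype)
    for pos in range(output_outer_size * pad_dim_size):
        if pos % pad_dim_size == 0:
            out[pos] = inp[pos // pad_dim_size]
    return out

print(calculate_padding_flat(np.array([1 + 1j, 2 + 2j], dtype=np.complex64), 2, 3, np.complex64))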
@@ -23,6 +23,7 @@
 #include "utils/tensor_construct_utils.h"
 #include "abstract/ops/primitive_infer_map.h"
 #include "mindapi/src/helper.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 namespace ops {
@@ -30,8 +31,16 @@ namespace {
 TypePtr PaddingInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto name = primitive->name();
-  const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
-                                         kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBool};
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
+  bool is_cpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice);
+  std::set<TypePtr> valid_types{};
+  if (is_gpu || is_cpu) {
+    valid_types = common_valid_types_with_complex_and_bool;
+  } else {
+    valid_types = common_valid_types_with_bool;
+  }
   return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, name);
 }
 
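In effect, the type-infer step now chooses its allowed dtype set from the current device target, so complex inputs are only accepted where a kernel exists (GPU and CPU); other targets keep the previous non-complex set. A rough Python rendering of that decision, where the set names are illustrative stand-ins for the C++ constants:

COMMON_WITH_BOOL = {"int8", "int16", "int32", "int64", "uint8", "uint16",
                    "uint32", "uint64", "float16", "float32", "float64", "bool"}
COMMON_WITH_COMPLEX_AND_BOOL = COMMON_WITH_BOOL | {"complex64", "complex128"}

def valid_padding_dtypes(device_target):
    # Complex dtypes are only registered for the GPU and CPU kernels.
    if device_target in ("GPU", "CPU"):
        return COMMON_WITH_COMPLEX_AND_BOOL
    return COMMON_WITH_BOOL

print("complex64" in valid_padding_dtypes("Ascend"))   # False
print("complex64" in valid_padding_dtypes("GPU"))      # True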
@@ -398,7 +398,7 @@ def test_invert_dynamic_shape(mode):
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('shape', [(2,), (4, 5), (3, 4, 5, 6)])
-@pytest.mark.parametrize('dtype, tol', [(np.float16, 1.0e-3), (np.float32, 1.0e-5)])
+@pytest.mark.parametrize('dtype, tol', [(np.float16, 1.0e-3), (np.float32, 1.0e-4), (np.float64, 1.0e-5)])
 def test_softsign(shape, dtype, tol):
     """
     Feature: ALL To ALL
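The widened parametrization adds a float64 case at 1.0e-5 and relaxes the float32 tolerance from 1.0e-5 to 1.0e-4. Softsign is x / (1 + |x|); a self-contained NumPy check in the same spirit as the test's comparison (illustrative only, not the test code):

import numpy as np

def softsign_np(x):
    return x / (1 + np.abs(x))

np.random.seed(0)
for dtype, tol in [(np.float16, 1.0e-3), (np.float32, 1.0e-4), (np.float64, 1.0e-5)]:
    x = np.random.randn(4, 5).astype(dtype)
    reference = softsign_np(x.astype(np.float64))        # high-precision reference
    assert np.allclose(softsign_np(x), reference, rtol=tol, atol=tol)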
@@ -48,7 +48,7 @@ class PaddingDynamicShapeNet(nn.Cell):
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
 @pytest.mark.parametrize('shape', [(2, 1), (2, 4, 1), (3, 4, 5, 1)])
-@pytest.mark.parametrize('dtype', [np.int32, np.float16, np.float32])
+@pytest.mark.parametrize('dtype', [np.bool_, np.uint32, np.float16, np.float32, np.complex64, np.complex128])
 @pytest.mark.parametrize('pad_dim_size', [2, 4, 10])
 def test_padding(mode, shape, dtype, pad_dim_size):
     """
@@ -47,7 +47,7 @@ class PaddingDynamicShapeNet(nn.Cell):
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
 @pytest.mark.parametrize('shape', [(2, 1), (2, 4, 1), (3, 4, 5, 1)])
-@pytest.mark.parametrize('dtype', [np.uint32, np.float16, np.float32])
+@pytest.mark.parametrize('dtype', [np.bool_, np.uint32, np.float16, np.float32, np.complex64, np.complex128])
 @pytest.mark.parametrize('pad_dim_size', [2, 4, 10])
 def test_padding(mode, shape, dtype, pad_dim_size):
     """
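The same dtype widening is applied in both Padding test files (presumably the CPU and GPU variants). Complex and boolean cases need input data that actually exercises the imaginary part and both truth values; one hedged way to build such inputs for the parametrized shapes (make_input is a hypothetical helper, not part of the test files):

import numpy as np

def make_input(shape, dtype):
    base = np.arange(np.prod(shape)).reshape(shape)
    if np.issubdtype(dtype, np.complexfloating):
        return (base + 1j * (base + 1)).astype(dtype)   # non-zero imaginary part
    if dtype == np.bool_:
        return (base % 2).astype(dtype)                  # mix of True and False
    return base.astype(dtype)

for shape in [(2, 1), (2, 4, 1), (3, 4, 5, 1)]:
    for dt in (np.bool_, np.uint32, np.float16, np.float32, np.complex64, np.complex128):
        assert make_input(shape, dt).dtype == dt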