forked from mindspore-Ecosystem/mindspore, commit 68c99f3ff7
@@ -362,6 +362,22 @@ __global__ void RoundKernel(const double *input, double *output, const size_t co
   return;
 }
+template <typename T>
+__global__ void SignKernel(const T *input, T *output, const size_t count) {
+  T zero = 0.0;
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    T res;
+    if (input[i] < zero) {
+      res = -1;
+    } else if (input[i] > zero) {
+      res = 1;
+    } else {
+      res = 0;
+    }
+    output[i] = static_cast<T>(res);
+  }
+  return;
+}
 template <typename T>
 void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
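SignKernel uses the standard CUDA grid-stride loop idiom: each thread starts at its global index and strides by the total number of threads in the grid, so the kernel is correct for any `count`, whatever launch configuration the GET_BLOCKS/GET_THREADS macros pick. A minimal self-contained sketch of the same pattern, with illustrative launch parameters in place of MindSpore's macros:

```cuda
// Minimal sketch of the grid-stride loop idiom used by SignKernel above.
// The 256 threads/block and the block count are illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void SignSketch(const float *input, float *output, size_t count) {
  // Each thread covers indices i, i + gridDim.x*blockDim.x, ... so even a
  // single block would still process every element.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    output[i] = (input[i] > 0.0f) - (input[i] < 0.0f);  // sign as {-1, 0, 1}
  }
}

int main() {
  const size_t count = 5;
  float h_in[count] = {2.0f, 0.0f, -1.0f, 3.5f, -0.5f};
  float h_out[count];
  float *d_in, *d_out;
  cudaMalloc(&d_in, count * sizeof(float));
  cudaMalloc(&d_out, count * sizeof(float));
  cudaMemcpy(d_in, h_in, count * sizeof(float), cudaMemcpyHostToDevice);
  const int threads = 256;
  const int blocks = static_cast<int>((count + threads - 1) / threads);
  SignSketch<<<blocks, threads>>>(d_in, d_out, count);
  cudaMemcpy(h_out, d_out, count * sizeof(float), cudaMemcpyDeviceToHost);
  for (size_t i = 0; i < count; ++i) printf("%g ", h_out[i]);  // 1 0 -1 1 -1
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```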
@@ -476,6 +492,11 @@ void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
   RoundKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
+template <typename T>
+void Sign(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  SignKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
 
 // double
 template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
@@ -500,6 +521,7 @@ template void Abs<double>(const double *input, double *output, const size_t coun
 template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 template void Round<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Sign<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 
 
 // float
@@ -525,6 +547,7 @@ template void Abs<float>(const float *input, float *output, const size_t count,
 template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Round<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Sign<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 
 // half
 template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
@@ -549,6 +572,7 @@ template void Abs<half>(const half *input, half *output, const size_t count, cud
 template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Round<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Sign<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 
 // int32
 template void Exponential<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
@@ -573,3 +597,4 @@ template void Abs<int>(const int *input, int *output, const size_t count, cudaSt
 template void Floor<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template void Rint<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template void Round<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Sign<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
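The per-type `template void Sign<...>` lines are explicit template instantiations. The kernel definitions live in a `.cu` translation unit compiled by nvcc, so every element type used by callers elsewhere must be instantiated here, or the linker reports undefined references. A minimal sketch of that header/source split (file names hypothetical):

```cpp
// unary_sketch.cuh (hypothetical) -- declarations only, safe to include
// from ordinary C++ translation units.
template <typename T>
void Sign(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

// unary_sketch.cu (hypothetical) -- definitions compiled by nvcc.
template <typename T>
void Sign(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  // launch SignKernel<T> on cuda_stream here
}
// Explicit instantiations emit the concrete symbols the linker needs;
// one line per supported element type, mirroring the diff above.
template void Sign<float>(const float *, float *, const size_t, cudaStream_t);
template void Sign<int>(const int *, int *, const size_t, cudaStream_t);
```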
@@ -62,5 +62,7 @@ template <typename T>
 void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void Sign(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
@@ -122,5 +122,13 @@ MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      UnaryOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
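Each `MS_REG_GPU_KERNEL_ONE` call binds the op name plus one input/output dtype signature to a concrete `UnaryOpGpuKernel<T>` instantiation, which is why `Sign` needs four separate registrations (int32, float64, float32, float16). A rough sketch of that kind of static-registration pattern (names hypothetical, not MindSpore's actual registry):

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical sketch of a kernel registry keyed by op name + dtype.
struct KernelBase {
  virtual ~KernelBase() = default;
};
template <typename T>
struct UnaryKernelSketch : KernelBase {};

using Factory = std::function<std::unique_ptr<KernelBase>()>;

std::map<std::string, Factory> &Registry() {
  static std::map<std::string, Factory> registry;  // avoids static init-order issues
  return registry;
}

// Expands to a namespace-scope bool whose initializer runs at program
// start and records a factory for one (op, dtype) pair.
#define REG_KERNEL_SKETCH(OP, DTYPE, T)                               \
  static bool g_reg_##OP##_##DTYPE = [] {                             \
    Registry()[#OP "_" #DTYPE] = [] {                                 \
      return std::unique_ptr<KernelBase>(new UnaryKernelSketch<T>()); \
    };                                                                \
    return true;                                                      \
  }();

REG_KERNEL_SKETCH(Sign, Int32, int)
REG_KERNEL_SKETCH(Sign, Float32, float)
```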
@@ -50,6 +50,7 @@ enum UnaryOptype {
   UNARY_OP_FLOOR,
   UNARY_OP_RINT,
   UNARY_OP_ROUND,
+  UNARY_OP_SIGN,
   UNARY_OP_INVALID_TYPE = 255
 };
 
@@ -64,7 +65,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
   {"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN},
   {"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
   {"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR},
-  {"Rint", UNARY_OP_RINT}, {"Round", UNARY_OP_ROUND}};
+  {"Rint", UNARY_OP_RINT}, {"Round", UNARY_OP_ROUND},
+  {"Sign", UNARY_OP_SIGN}};
 
 template <typename T>
 class UnaryOpGpuKernel : public GpuKernel {
@@ -170,6 +172,10 @@ class UnaryOpGpuKernel : public GpuKernel {
         Round(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_SIGN: {
+        Sign(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
       }
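Note that `inputs[0]->size` is a byte count, so dividing by `sizeof(T)` yields the element count passed to the launcher; the new `UNARY_OP_SIGN` case follows the same idiom as every other case in this switch.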
@@ -3883,7 +3883,7 @@ class Sign(PrimitiveWithInfer):
         TypeError: If `input_x` is not a Tensor.
 
     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``CPU`` ``GPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[2.0, 0.0, -1.0]]), mindspore.float32)
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, ops
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.sign = ops.Sign()
+
+    def construct(self, x):
+        return self.sign(x)
+
+
+def generate_testcases(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x = np.array([2.0, 0.0, -1.0]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.array([1.0, 0.0, -1.0]).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    x = np.array([2.0, 0.0, -1.0]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.array([1.0, 0.0, -1.0]).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sign_int32():
+    generate_testcases(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sign_float32():
+    generate_testcases(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sign_float16():
+    generate_testcases(np.float16)