forked from mindspore-Ecosystem/mindspore
!3560 Add random uniform real op at GPU end
Merge pull request !3560 from peixu_ren/custom_gpu
This commit is contained in:
commit
efb12d5e1e
|
@ -24,6 +24,18 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size
|
|||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void UniformKernel(int seed, curandState *globalState, T *input1, size_t input_size_1,
|
||||
T *input2, size_t input_size_2, T *output, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
input1[i] = (input_size_1 == 1 ? input1[0] : input1[i]);
|
||||
input2[i] = (input_size_2 == 1 ? input2[0] : input2[i]);
|
||||
curand_init(seed, i, 0, &globalState[i]);
|
||||
output[i] = curand_uniform(&globalState[i]) * (input2[i] - input1[i]) + input1[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) {
|
||||
int RNG_seed = 0;
|
||||
|
@ -38,5 +50,17 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
|
|||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void UniformReal(int seed, curandState *globalState, T *input1, size_t input_size_1,
|
||||
T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) {
|
||||
seed = (seed == 0 ? time(NULL):seed);
|
||||
UniformKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
|
||||
(seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
|
||||
return;
|
||||
}
|
||||
|
||||
template void StandardNormal<float>(int seed, int seed2, curandState *globalState,
|
||||
float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void UniformReal<float>(int seed, curandState *globalState, float *input1, size_t input_size_1,
|
||||
float *input2, size_t input_size_2, float *output, size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -23,4 +23,8 @@
|
|||
template <typename T>
|
||||
void StandardNormal(int seed, int seed2, curandState *globalState,
|
||||
T *output, size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void UniformReal(int seed, curandState *globalState,
|
||||
T *input1, size_t input_size_1, T *input2, size_t input_size_2,
|
||||
T *output, size_t count, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_
|
||||
|
|
|
@ -20,5 +20,12 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
RandomOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(UniformReal,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
RandomOpGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,17 +28,22 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 };
|
||||
enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };
|
||||
|
||||
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}};
|
||||
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL},
|
||||
{"UniformReal", RANDOM_OP_UNIFORM_REAL}};
|
||||
template <typename T>
|
||||
class RandomOpGpuKernel : public GpuKernel {
|
||||
public:
|
||||
RandomOpGpuKernel()
|
||||
: random_op_type_(RANDOM_OP_INVALID_TYPE),
|
||||
input_size_0_(0),
|
||||
input_size_0_(sizeof(int)),
|
||||
input_size_1_(sizeof(T)),
|
||||
input_size_2_(sizeof(T)),
|
||||
output_size_(sizeof(T)),
|
||||
workspace_size_(sizeof(curandState)) {}
|
||||
workspace_size_(sizeof(curandState)),
|
||||
seed_(0),
|
||||
seed2_(0) {}
|
||||
~RandomOpGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -57,12 +62,21 @@ class RandomOpGpuKernel : public GpuKernel {
|
|||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case RANDOM_OP_UNIFORM_REAL: {
|
||||
T *input_addr_1 = GetDeviceAddress<T>(inputs, 1);
|
||||
T *input_addr_2 = GetDeviceAddress<T>(inputs, 2);
|
||||
UniformReal(seed_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
|
||||
inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported.";
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto iter = kRandomOpTypeMap.find(kernel_name);
|
||||
|
@ -72,10 +86,14 @@ class RandomOpGpuKernel : public GpuKernel {
|
|||
random_op_type_ = iter->second;
|
||||
}
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
if (random_op_type_ == RANDOM_OP_NORMAL && input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL && input_num != 3) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 3 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output.";
|
||||
|
@ -86,13 +104,25 @@ class RandomOpGpuKernel : public GpuKernel {
|
|||
input_size_0_ += input_shape_0[i];
|
||||
}
|
||||
input_size_0_ *= sizeof(int);
|
||||
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
|
||||
auto input_shape_1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
for (size_t i = 0; i < input_shape_1.size(); i++) {
|
||||
input_size_1_ *= input_shape_1[i];
|
||||
}
|
||||
auto input_shape_2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
for (size_t i = 0; i < input_shape_2.size(); i++) {
|
||||
input_size_2_ *= input_shape_2[i];
|
||||
}
|
||||
}
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
workspace_size_ *= output_shape[i];
|
||||
}
|
||||
seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
|
||||
seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
|
||||
if (random_op_type_ == RANDOM_OP_NORMAL) {
|
||||
seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -100,6 +130,10 @@ class RandomOpGpuKernel : public GpuKernel {
|
|||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_0_);
|
||||
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
|
||||
input_size_list_.push_back(input_size_1_);
|
||||
input_size_list_.push_back(input_size_2_);
|
||||
}
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size_list_.push_back(workspace_size_);
|
||||
}
|
||||
|
@ -107,6 +141,8 @@ class RandomOpGpuKernel : public GpuKernel {
|
|||
private:
|
||||
RandomOptype random_op_type_;
|
||||
size_t input_size_0_;
|
||||
size_t input_size_1_;
|
||||
size_t input_size_2_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
int seed_;
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, shape, seed=0):
|
||||
super(Net, self).__init__()
|
||||
self.uniformreal = P.UniformReal(seed=seed)
|
||||
self.shape = shape
|
||||
|
||||
def construct(self, a, b):
|
||||
return self.uniformreal(self.shape, a, b)
|
||||
|
||||
|
||||
def test_net_1D():
|
||||
seed = 10
|
||||
shape = (3, 2, 4)
|
||||
a = 0.0
|
||||
b = 1.0
|
||||
net = Net(shape, seed)
|
||||
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
|
||||
output = net(ta, tb)
|
||||
print(output.asnumpy())
|
||||
assert output.shape == (3, 2, 4)
|
Loading…
Reference in New Issue