!8068 add supports to op gathergrad on gpu

Merge pull request !8068 from zhouyuanshen/gathergrad
This commit is contained in:
mindspore-ci-bot 2020-10-31 15:58:33 +08:00 committed by Gitee
commit ccbc6df79c
7 changed files with 433 additions and 0 deletions

View File

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
GatherGradGpuKernel, int, float)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
GatherGradGpuKernel, int64_t, float)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
GatherGradGpuKernel, int, half)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
GatherGradGpuKernel, int64_t, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,124 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_GATHER_GRAD_GPU_KERNEL_H
#define MINDSPORE_GATHER_GRAD_GPU_KERNEL_H
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class GatherGradGpuKernel : public GpuKernel {
public:
GatherGradGpuKernel() : axis_(0), handle_(nullptr) {}
~GatherGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *index_addr = GetDeviceAddress<T>(inputs, 0);
S *grad_addr = GetDeviceAddress<S>(inputs, 1);
S *output_addr = GetDeviceAddress<S>(outputs, 0);
GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2],
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGradGpuKernel needs 2.";
}
index_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
grad_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
axis_ = GetAttr<int>(kernel_node, "dim");
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(index_shapes_.size());
}
Reshape();
InitSizeLists();
return true;
}
protected:
void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() override {
size_t size = GetSize(index_shapes_, true);
input_size_list_.push_back(size);
size = GetSize(grad_shapes_, false);
input_size_list_.push_back(size);
size = GetSize(output_shapes_, false);
output_size_list_.push_back(size);
}
private:
void Reshape() {
size_t dim_before_axis = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_of_indices = output_shapes_[IntToSize(axis_)];
size_t dim_after_indices = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_indices *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_of_indices;
dims_[2] = dim_after_indices;
return;
}
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
if (shape.size() == 0) {
return 0;
}
size_t result = flag ? sizeof(T) : sizeof(S);
for (size_t i = 0; i < shape.size(); i++) {
result *= shape[i];
}
return result;
}
std::vector<size_t> index_shapes_;
std::vector<size_t> grad_shapes_;
std::vector<size_t> output_shapes_;
size_t dims_[3] = {};
int axis_;
cudnnHandle_t handle_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_GATHER_GRAD_GPU_KERNEL_H

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void GatherGradKernel(const T *index, const S *grad, S *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2) {
size_t num = output_dim0 * output_dim1 * output_dim2;
size_t i, k;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num;
id += blockDim.x * gridDim.x) {
i = id / (output_dim1 * output_dim2) % output_dim0;
k = id % output_dim2;
size_t j_read = static_cast<size_t>(index[id]);
size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k;
output[read_id] = grad[id];
}
return;
}
template <typename T, typename S>
void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream) {
size_t size = output_dim0 * output_dim1 * output_dim2;
GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(index, grad, output,
output_dim0, output_dim1, output_dim2);
return;
}
template void GatherGrad<int, float>(const int *index, const float *grad, float *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
template void GatherGrad<int, half>(const int *index, const half *grad, half *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
template void GatherGrad<int64_t, float>(const int64_t *index, const float *grad, float *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
template void GatherGrad<int64_t, half>(const int64_t *index, const half *grad, half *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);

View File

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_GATHER_GRAD_GPU_CU_H
#define MINDSPORE_GATHER_GRAD_GPU_CU_H
template <typename T, typename S>
void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
#endif

View File

@ -378,6 +378,15 @@ def get_bprop_gather_v2(self):
return bprop
@bprop_getters.register(P.GatherD)
def get_bprop_gather_d(self):
def bprop(x, dim, index, out, dout):
return P.GatherDGrad(dim)(index, dout)
return bprop
@bprop_getters.register(P.SparseGatherV2)
def get_bprop_sparse_gather_v2(self):
"""Generate bprop for SparseGatherV2"""

View File

@ -1234,6 +1234,23 @@ class EluGrad(PrimitiveWithInfer):
return x_dtype
class GatherDGrad(PrimitiveWithInfer):
"""Performs grad of GatherD operation."""
@prim_attr_register
def __init__(self, dim=0):
"""Initialize GatherDGrad"""
validator.check_is_int(dim, int)
self.add_prim_attr("dim", dim)
self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
def infer_shape(self, index_shape, grad_shape):
return grad_shape
def infer_dtype(self, index_dtype, grad_dtype):
return grad_dtype
class ResizeBilinearGrad(PrimitiveWithInfer):
"""Performs grad of ResizeBilinear operation."""

View File

@ -0,0 +1,163 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops.operations._grad_ops as P
from mindspore import Tensor
class GatherDGradNet(nn.Cell):
def __init__(self, dim=0):
super(GatherDGradNet, self).__init__()
self.gather_d_grad = P.GatherDGrad(dim)
def construct(self, index, grad):
return self.gather_d_grad(index, grad)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int32_fp32():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
net = GatherDGradNet(dim)
output = net(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int64_fp32():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
net = GatherDGradNet(dim)
output = net(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int32_fp16():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
net = GatherDGradNet(dim)
output = net(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int64_fp16():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
net = GatherDGradNet(dim)
output = net(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int32_fp32():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
output = P.GatherDGrad(dim)(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int64_fp32():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
output = P.GatherDGrad(dim)(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int32_fp16():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
output = P.GatherDGrad(dim)(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int64_fp16():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dim = 0
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
output = P.GatherDGrad(dim)(index, grad)
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)