!3803 add gpu klDivLoss op

Merge pull request !3803 from baihuawei/loss
This commit is contained in:
mindspore-ci-bot 2020-08-03 10:07:34 +08:00 committed by Gitee
commit 49ba473bcc
5 changed files with 316 additions and 0 deletions

View File

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
KLDivLoss,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KLDivLossGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class KLDivLossGpuKernel : public GpuKernel {
public:
KLDivLossGpuKernel() : input_size_(1), reduction_(1) {}
~KLDivLossGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_x = GetDeviceAddress<T>(inputs, 0);
T *input_y = GetDeviceAddress<T>(inputs, 1);
T *loss = GetDeviceAddress<T>(outputs, 0);
KLDivLoss(input_size_, reduction_, input_x, input_y, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") {
reduction_ = 0;
} else if (reduction == "sum") {
reduction_ = 2;
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_size_ * sizeof(T));
if (reduction_ == 0) {
output_size_list_.push_back(input_size_ * sizeof(T));
} else {
output_size_list_.push_back(sizeof(T));
}
}
private:
size_t input_size_;
int reduction_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H

View File

@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(KLDivLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KLDivLossGradGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,88 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class KLDivLossGradGpuKernel : public GpuKernel {
public:
KLDivLossGradGpuKernel() : input_size_(1), reduction_(1) {}
~KLDivLossGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_x = GetDeviceAddress<T>(inputs, 0);
T *input_y = GetDeviceAddress<T>(inputs, 1);
T *dloss = GetDeviceAddress<T>(inputs, 2);
T *dx = GetDeviceAddress<T>(outputs, 0);
T *dy = GetDeviceAddress<T>(outputs, 1);
KLDivLossGrad(input_size_, reduction_, input_x, input_y, dloss, dx, dy, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") {
reduction_ = 0;
} else if (reduction == "sum") {
reduction_ = 2;
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(input_size_ * sizeof(T));
if (reduction_ == 0) {
input_size_list_.push_back(input_size_ * sizeof(T));
} else {
input_size_list_.push_back(sizeof(T));
}
}
private:
size_t input_size_;
int reduction_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_KL_DIV_LOSS_GRAD_KERNEL_H

View File

@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self, reduction="none"):
super(Net, self).__init__()
self.KLDivLoss = P.KLDivLoss("none")
def construct(self, x, y):
return self.KLDivLoss(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_binary_cross_entropy_loss():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float32)
target = np.random.rand(20).astype(np.float32)
net = Net()
loss = net(Tensor(prediction), Tensor(target))
expect = [-0.5297444, -0.40738472, -0.5733339, -0.58720195, -0.42922008, -0.31237593,
-0.3332863, -0.78742254, -0.6662671, -0.17546377, -0.31526336, -0.46702948,
-0.23191005, -0.2512708, -0.20934652, -0.32021108, -0.45477402, -0.278453,
-0.5551879, -0.48938933]
assert np.allclose(loss.asnumpy(), expect)
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
self.network = network
def construct(self, x1, x2, sens):
gout = self.grad(self.network)(x1, x2, sens)
return gout
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_binary_cross_entropy_loss_grad():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float32)
target = np.random.rand(20).astype(np.float32)
sens = np.random.rand(20).astype(np.float32)
grad = Grad(Net())
dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656,
-0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884,
-0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206,
-0.03094601, -0.14319494]
dx2_expect = [0.0163771, -0.950962, -0.03309895, -0.5481312, 0.01523498, 0.39894313,
-0.20858267, -0.27628726, -0.06815486, -0.5134226, 0.46645382, -1.3477919,
-2.409831, 0.65787154, 0.4682768, 0.55671424, -0.04362264, -0.36274382,
0.00852979, -0.03639247]
assert np.allclose(dx[0].asnumpy(), dx1_expect)
assert np.allclose(dx[1].asnumpy(), dx2_expect)