diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc
new file mode 100644
index 00000000000..20cb060ea06
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  KLDivLoss,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  KLDivLossGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
new file mode 100644
index 00000000000..43aced94941
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GPU_KERNEL_H
+
+#include <vector>
+#include <string>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+// Forward GPU kernel for KLDivLoss: reads two inputs of input_size_ elements each and writes the loss.
+// With reduction "none" the output has one value per input element; otherwise it is a single value.
+template <typename T>
+class KLDivLossGpuKernel : public GpuKernel {
+ public:
+  // reduction_ starts at 1, which is what Init() keeps for any attribute value other than "none"/"sum".
+  KLDivLossGpuKernel() : input_size_(1), reduction_(1) {}
+  ~KLDivLossGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  // Forwards the device buffers to the KLDivLoss helper on the given stream. The workspace list is unused.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_x = GetDeviceAddress<T>(inputs, 0);
+    T *input_y = GetDeviceAddress<T>(inputs, 1);
+    T *loss = GetDeviceAddress<T>(outputs, 0);
+
+    KLDivLoss(input_size_, reduction_, input_x, input_y, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  // Derives the element count from the inferred shape of input 0 and maps the "reduction" attribute
+  // to the integer code handed to the KLDivLoss helper: "none" -> 0, "sum" -> 2, anything else -> 1.
+  bool Init(const CNodePtr &kernel_node) override {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (const auto &dim : input_shape) {
+      input_size_ *= dim;
+    }
+
+    std::string reduction = GetAttr<std::string>(kernel_node, "reduction");
+    if (reduction == "none") {
+      reduction_ = 0;
+    } else if (reduction == "sum") {
+      reduction_ = 2;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  // Byte sizes: both inputs hold input_size_ elements; the output is per-element only when reduction_ is 0.
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    if (reduction_ == 0) {
+      output_size_list_.push_back(input_size_ * sizeof(T));
+    } else {
+      output_size_list_.push_back(sizeof(T));
+    }
+  }
+
+ private:
+  size_t input_size_;  // product of the dimensions of input 0
+  int reduction_;      // 0 for "none", 2 for "sum", 1 for any other attribute value
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc
new file mode 100644
index 00000000000..83371f580b0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(KLDivLossGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      KLDivLossGradGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
new file mode 100644
index 00000000000..37a0c76a8c3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
+
+#include <vector>
+#include <string>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+// Backward GPU kernel for KLDivLoss: reads x, y and the incoming gradient dloss and writes dx and dy.
+// dloss has one value per element with reduction "none" and a single value otherwise.
+template <typename T>
+class KLDivLossGradGpuKernel : public GpuKernel {
+ public:
+  // reduction_ starts at 1, which is what Init() keeps for any attribute value other than "none"/"sum".
+  KLDivLossGradGpuKernel() : input_size_(1), reduction_(1) {}
+  ~KLDivLossGradGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  // Forwards the device buffers to the KLDivLossGrad helper on the given stream. The workspace list is unused.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_x = GetDeviceAddress<T>(inputs, 0);
+    T *input_y = GetDeviceAddress<T>(inputs, 1);
+    T *dloss = GetDeviceAddress<T>(inputs, 2);
+    T *dx = GetDeviceAddress<T>(outputs, 0);
+    T *dy = GetDeviceAddress<T>(outputs, 1);
+    KLDivLossGrad(input_size_, reduction_, input_x, input_y, dloss, dx, dy, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  // Derives the element count from the inferred shape of input 0 and maps the "reduction" attribute
+  // to the integer code handed to the KLDivLossGrad helper: "none" -> 0, "sum" -> 2, anything else -> 1.
+  bool Init(const CNodePtr &kernel_node) override {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (const auto &dim : input_shape) {
+      input_size_ *= dim;
+    }
+
+    std::string reduction = GetAttr<std::string>(kernel_node, "reduction");
+    if (reduction == "none") {
+      reduction_ = 0;
+    } else if (reduction == "sum") {
+      reduction_ = 2;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  // Byte sizes, in index order: inputs x, y, dloss (per-element only when reduction_ is 0); outputs dx, dy.
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    if (reduction_ == 0) {
+      input_size_list_.push_back(input_size_ * sizeof(T));
+    } else {
+      input_size_list_.push_back(sizeof(T));
+    }
+    output_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(input_size_ * sizeof(T));
+  }
+
+ private:
+  size_t input_size_;  // product of the dimensions of input 0
+  int reduction_;      // 0 for "none", 2 for "sum", 1 for any other attribute value
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
diff --git a/tests/st/ops/gpu/test_kl_div_op.py b/tests/st/ops/gpu/test_kl_div_op.py
new file mode 100644
index 00000000000..e5b8fcd0799
--- /dev/null
+++ b/tests/st/ops/gpu/test_kl_div_op.py
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class Net(nn.Cell):
+    """Cell that applies P.KLDivLoss with the requested reduction mode."""
+
+    def __init__(self, reduction="none"):
+        super(Net, self).__init__()
+        self.KLDivLoss = P.KLDivLoss(reduction)
+
+    def construct(self, x, y):
+        return self.KLDivLoss(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_kl_div_loss():
+    """Forward pass with reduction "none", compared with precomputed per-element values."""
+    np.random.seed(42)
+    prediction = np.random.rand(20).astype(np.float32)
+    target = np.random.rand(20).astype(np.float32)
+    net = Net()
+    loss = net(Tensor(prediction), Tensor(target))
+    expect = [-0.5297444, -0.40738472, -0.5733339, -0.58720195, -0.42922008, -0.31237593,
+              -0.3332863, -0.78742254, -0.6662671, -0.17546377, -0.31526336, -0.46702948,
+              -0.23191005, -0.2512708, -0.20934652, -0.32021108, -0.45477402, -0.278453,
+              -0.5551879, -0.48938933]
+    assert np.allclose(loss.asnumpy(), expect)
+
+
+class Grad(nn.Cell):
+    """Cell that runs GradOperation (get_all=True, sens_param=True) on the wrapped network."""
+
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, x1, x2, sens):
+        gout = self.grad(self.network)(x1, x2, sens)
+        return gout
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_kl_div_loss_grad():
+    """Backward pass with reduction "none": both returned gradients are compared with precomputed values."""
+    np.random.seed(42)
+    prediction = np.random.rand(20).astype(np.float32)
+    target = np.random.rand(20).astype(np.float32)
+    sens = np.random.rand(20).astype(np.float32)
+    grad = Grad(Net())
+    dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
+
+    dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656,
+                  -0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884,
+                  -0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206,
+                  -0.03094601, -0.14319494]
+
+    dx2_expect = [0.0163771, -0.950962, -0.03309895, -0.5481312, 0.01523498, 0.39894313,
+                  -0.20858267, -0.27628726, -0.06815486, -0.5134226, 0.46645382, -1.3477919,
+                  -2.409831, 0.65787154, 0.4682768, 0.55671424, -0.04362264, -0.36274382,
+                  0.00852979, -0.03639247]
+
+    assert np.allclose(dx[0].asnumpy(), dx1_expect)
+    assert np.allclose(dx[1].asnumpy(), dx2_expect)