diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cu new file mode 100644 index 00000000000..7dc2528eeb3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cu @@ -0,0 +1,128 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh" + +__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; } + +template +__global__ void FillWithoutBroadcast(const size_t size, const T *src, T *dst) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + dst[pos] = src[pos]; + } + return; +} + +template +__global__ void FillAndBroadcast(const size_t size, const size_t shape_size, const size_t *src_shape, + const size_t *dst_shape, const T *src, T *dst) { + size_t dst_index_array[MAX_LOGITS_DIMENSION]; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + size_t tmp_pos = pos; + size_t pos_size = size / dst_shape[0]; + dst_index_array[0] = tmp_pos / pos_size; + for (size_t i = 1; i < shape_size; i++) { + tmp_pos -= dst_index_array[i - 1] * pos_size; + pos_size = pos_size / dst_shape[i]; + dst_index_array[i] = tmp_pos / pos_size; + } + size_t src_pos = 0; + size_t src_size = 1; + for (size_t i = 0; i < shape_size; i++) { + src_size *= src_shape[i]; + } + for (size_t i = 0; i < shape_size; i++) { + src_size /= src_shape[i]; + size_t length_by_index = Index(dst_index_array[i], src_shape[i]) * src_size; + src_pos += length_by_index; + } + dst[pos] = src[src_pos]; + } + return; +} + +template +__global__ void BCEWithLogitsLossMain(size_t size, const T *predict, const T *target, const T *shape_broadcasted, + T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + T max_value = -predict[pos]; + max_value = max_value > static_cast(0) ? max_value : static_cast(0); + const T log_weight = (shape_broadcasted[pos] - static_cast(1)) * target[pos] + static_cast(1); + output[pos] = (static_cast(1) - target[pos]) * predict[pos] + + log_weight * (log(exp(-max_value) + exp(-predict[pos] - max_value)) + max_value); + } + return; +} + +template <> +__global__ void BCEWithLogitsLossMain(size_t size, const half *predict, const half *target, + const half *shape_broadcasted, half *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + half max_value = -predict[pos]; + max_value = max_value > static_cast(0) ? max_value : static_cast(0); + const half log_weight = (shape_broadcasted[pos] - static_cast(1)) * target[pos] + static_cast(1); + output[pos] = (static_cast(1) - target[pos]) * predict[pos] + + log_weight * (hlog(hexp(-max_value) + hexp(-predict[pos] - max_value)) + max_value); + } + return; +} + +template +__global__ void Mul(size_t size, const T *lhs, T *rhs) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + rhs[pos] *= lhs[pos]; + } + return; +} + +template +void CalBCEWithLogitsLoss(const size_t input_size, const T *predict, const T *target, const size_t *input_shape, + const size_t shape_size, const T *weight, const size_t *weight_shape, + const bool weight_need_broadcast, const T *pos_weight, const size_t *pos_weight_shape, + const bool pos_weight_need_broadcast, T *shape_broadcasted, T *output, + cudaStream_t cuda_stream) { + if (pos_weight_need_broadcast) { + FillAndBroadcast<<>>( + input_size, shape_size, pos_weight_shape, input_shape, pos_weight, shape_broadcasted); + } else { + FillWithoutBroadcast<<>>(input_size, pos_weight, + shape_broadcasted); + } + BCEWithLogitsLossMain<<>>(input_size, predict, target, + shape_broadcasted, output); + if (weight_need_broadcast) { + FillAndBroadcast<<>>(input_size, shape_size, weight_shape, + input_shape, weight, shape_broadcasted); + } else { + FillWithoutBroadcast<<>>(input_size, weight, + shape_broadcasted); + } + Mul<<>>(input_size, shape_broadcasted, output); + return; +} + +template void CalBCEWithLogitsLoss(const size_t input_size, const half *predict, const half *target, + const size_t *input_shape, const size_t shape_size, const half *weight, + const size_t *weight_shape, const bool weight_need_broadcast, + const half *pos_weight, const size_t *pos_weight_shape, + const bool pos_weight_need_broadcast, half *shape_broadcasted, half *output, + cudaStream_t cuda_stream); +template void CalBCEWithLogitsLoss(const size_t input_size, const float *predict, const float *target, + const size_t *input_shape, const size_t shape_size, const float *weight, + const size_t *weight_shape, const bool weight_need_broadcast, + const float *pos_weight, const size_t *pos_weight_shape, + const bool pos_weight_need_broadcast, float *shape_broadcasted, float *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh new file mode 100644 index 00000000000..87f3c1ac58d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_ + +#define MAX_LOGITS_DIMENSION 100 +#include "runtime/device/gpu/cuda_common.h" + +template +void CalBCEWithLogitsLoss(const size_t input_size, const T *predict, const T *target, const size_t *input_shape, + const size_t shape_size, const T *weight, const size_t *weight_shape, + const bool weight_need_broadcast, const T *pos_weight, const size_t *pos_weight_shape, + const bool pos_weight_need_broadcast, T *shape_broadcasted, T *output, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.cc new file mode 100644 index 00000000000..c5bae2dbcab --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BCEWithLogitsLoss, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BCEWithLogitsLossKernel, float) +MS_REG_GPU_KERNEL_ONE(BCEWithLogitsLoss, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + BCEWithLogitsLossKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.h new file mode 100644 index 00000000000..85aadfcde5a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.h @@ -0,0 +1,167 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BCEWithLogitsLossKernel : public GpuKernel { + public: + BCEWithLogitsLossKernel() { ResetResource(); } + ~BCEWithLogitsLossKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *predict = GetDeviceAddress(inputs, 0); + T *target = GetDeviceAddress(inputs, 1); + T *weight = GetDeviceAddress(inputs, 2); + T *pos_weight = GetDeviceAddress(inputs, 3); + size_t *input_shape = GetDeviceAddress(workspace, 0); + size_t *weight_shape = GetDeviceAddress(workspace, 1); + size_t *pos_weight_shape = GetDeviceAddress(workspace, 2); + T *shape_broadcasted = GetDeviceAddress(workspace, 3); + T *output = GetDeviceAddress(outputs, 0); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(input_shape, &input_shape_[0], input_shape_.size() * sizeof(size_t), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape_ failed"); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(weight_shape, &weight_shape_[0], weight_shape_.size() * sizeof(size_t), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync weight_shape_ failed"); + CHECK_CUDA_RET_WITH_EXCEPT( + kernel_node_, + cudaMemcpyAsync(pos_weight_shape, &pos_weight_shape_[0], pos_weight_shape_.size() * sizeof(size_t), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync pos_weight_shape_ failed"); + CalBCEWithLogitsLoss(input_size_, predict, target, input_shape, input_shape_.size(), weight, weight_shape, + weight_need_broadcast_, pos_weight, pos_weight_shape, pos_weight_need_broadcast_, + shape_broadcasted, output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 4) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but BCEWithLogitsLoss needs 4 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but BCEWithLogitsLoss has 1 output."; + return false; + } + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = 1; + if (input_shape_.size() > MAX_LOGITS_DIMENSION) { + MS_LOG(EXCEPTION) << "Input dimension is " << input_shape_.size() + << ", but BCEWithLogitsLoss can only support up to " << MAX_LOGITS_DIMENSION << "-D."; + return false; + } + for (size_t i = 0; i < input_shape_.size(); i++) { + input_size_ *= input_shape_[i]; + } + // weight shape + weight_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + weight_size_ = 1; + for (size_t i = 0; i < weight_shape_.size(); i++) { + weight_size_ *= weight_shape_[i]; + } + weight_need_broadcast_ = NeedBroadcast(&weight_shape_, input_shape_); + // pos_weight shape + pos_weight_size_ = 1; + pos_weight_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + for (size_t i = 0; i < pos_weight_shape_.size(); i++) { + pos_weight_size_ *= pos_weight_shape_[i]; + } + pos_weight_need_broadcast_ = NeedBroadcast(&pos_weight_shape_, input_shape_); + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + input_size_ = 1; + weight_size_ = 1; + pos_weight_size_ = 1; + weight_need_broadcast_ = false; + pos_weight_need_broadcast_ = false; + input_shape_.clear(); + weight_shape_.clear(); + pos_weight_shape_.clear(); + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(weight_size_ * sizeof(T)); + input_size_list_.push_back(pos_weight_size_ * sizeof(T)); + workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t)); + workspace_size_list_.push_back(weight_shape_.size() * sizeof(size_t)); + workspace_size_list_.push_back(pos_weight_shape_.size() * sizeof(size_t)); + // extra space for holding extra array shape of input, for broadcasted + // weight and pos_weight + workspace_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(input_size_ * sizeof(T)); + } + + private: + bool NeedBroadcast(std::vector *shape, const std::vector &result_shape) { + // result_shape is larger that shape + // and shape is able to broadcasted to result_shape + if (shape->size() != result_shape.size()) { + size_t fill_size = result_shape.size() - shape->size(); + (void)shape->insert(shape->begin(), fill_size, 1); + return true; + } + for (size_t i = 0; i < result_shape.size(); i++) { + if (shape->at(i) != result_shape[i]) { + return true; + } + } + return false; + } + + size_t input_size_; + size_t weight_size_; + size_t pos_weight_size_; + bool weight_need_broadcast_; + bool pos_weight_need_broadcast_; + std::vector input_shape_; + std::vector weight_shape_; + std::vector pos_weight_shape_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/bce_with_logits_loss_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/bce_with_logits_loss_fusion.cc new file mode 100644 index 00000000000..e73821f4d5f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/bce_with_logits_loss_fusion.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/gpu/bce_with_logits_loss_fusion.h" +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/gpu/kernel_info_setter.h" + +namespace mindspore { +namespace opt { +AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::vector node_inputs = { + NewValueNode(std::make_shared(prim::kPrimBCEWithLogitsLoss->name()))}; + (void)node_inputs.insert(node_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + CNodePtr new_cnode = func_graph->NewCNode(node_inputs); + MS_EXCEPTION_IF_NULL(new_cnode); + auto predict_input = cnode->inputs()[1]; + auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)}; + auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get()); + + // Add reduce node + string reduction = AnfAlgo::GetNodeAttr(node, kAttrReduction); + MS_LOG(INFO) << "Create reduce node for BCEWithLogitsLoss, reduction attr is: " << reduction; + std::vector reduce_inputs; + if (reduction == "sum") { + reduce_inputs = {NewValueNode(std::make_shared(prim::kPrimReduceSum->name())), new_cnode}; + } else if (reduction == "mean") { + reduce_inputs = {NewValueNode(std::make_shared(prim::kPrimReduceMean->name())), new_cnode}; + } else { + MS_LOG(INFO) << "Reduction is none, no optimization on current BCEWithLogitsLoss."; + return nullptr; + } + auto reduce_node = func_graph->NewCNode(reduce_inputs); + MS_EXCEPTION_IF_NULL(reduce_node); + auto type = AnfAlgo::GetOutputInferDataType(node, 0); + auto shape = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape({type}, shape, reduce_node.get()); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector{}), reduce_node); + AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node); + reduce_node->set_scope(cnode->scope()); + return reduce_node; +} + +const BaseRef BCEWithLogitsLossFusion::DefinePattern() const { + VarPtr Xs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + return VectorRef({prim::kPrimBCEWithLogitsLoss, Xs}); +} + +const AnfNodePtr BCEWithLogitsLossFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (GetBoolAttr(cnode, kAttrVisited)) { + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + if (cnode->inputs().size() == 0) { + return nullptr; + } + if (!AnfAlgo::HasNodeAttr("reduction", cnode)) { + MS_LOG(INFO) << "Primitive BCEWithLogitsLoss doesn't not have reduction attr."; + return nullptr; + } + return AddReduceNode(func_graph, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/bce_with_logits_loss_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/bce_with_logits_loss_fusion.h new file mode 100644 index 00000000000..93e2a6aad35 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/bce_with_logits_loss_fusion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BCEWithLogitsLossFusion : public PatternProcessPass { + public: + explicit BCEWithLogitsLossFusion(bool multigraph = true) + : PatternProcessPass("bce_with_logits_loss_fusion", multigraph) {} + ~BCEWithLogitsLossFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_ diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 0b19137fb64..132a67cd61f 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -38,6 +38,7 @@ #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" #include "backend/optimizer/gpu/replace_addn_fusion.h" #include "backend/optimizer/gpu/print_reduce_fusion.h" +#include "backend/optimizer/gpu/bce_with_logits_loss_fusion.h" #include "backend/optimizer/gpu/remove_format_transform_pair.h" #include "backend/optimizer/gpu/remove_redundant_format_transform.h" #include "backend/optimizer/gpu/reduce_precision_fusion.h" @@ -143,6 +144,7 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared("print_reduce")); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/tests/st/ops/gpu/test_bce_with_logits_loss.py b/tests/st/ops/gpu/test_bce_with_logits_loss.py new file mode 100644 index 00000000000..a66da7a09a4 --- /dev/null +++ b/tests/st/ops/gpu/test_bce_with_logits_loss.py @@ -0,0 +1,102 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import math +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self, reduction): + super(Net, self).__init__() + self.loss = P.BCEWithLogitsLoss(reduction=reduction) + + def construct(self, predict, target, weight, pos_weight): + return self.loss(predict, target, weight, pos_weight) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_reduction_none_testcases(): + # fp32 + both modes + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + loss = Net("none") + predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32)) + target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32)) + weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32)) + pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32)) + output = loss(predict, target, weight, pos_weight) + expected = np.array([[0.6111006, 0.5032824, 0.26318598], + [0.58439666, 0.55301523, -0.436814]]).astype(np.float32) + np.testing.assert_almost_equal(expected, output.asnumpy()) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + loss = Net("none") + predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32)) + target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32)) + weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32)) + pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32)) + output = loss(predict, target, weight, pos_weight) + expected = np.array([[0.6111006, 0.5032824, 0.26318598], + [0.58439666, 0.55301523, -0.436814]]) + np.testing.assert_almost_equal(expected, output.asnumpy()) + # fp16 + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + loss = Net("none") + predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float16)) + target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float16)) + weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float16)) + pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float16)) + output = loss(predict, target, weight, pos_weight) + expected = np.array([[0.611, 0.503, 0.2627], + [0.584, 0.5527, -0.437]]).astype(np.float16) + np.testing.assert_almost_equal(expected, output.asnumpy(), decimal=3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_reduction_mean_testcases(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + loss = Net("mean") + predict = Tensor(np.arange(6).reshape(2, 3).astype(np.float32)) + target = Tensor(np.arange(34, 40).reshape(2, 3).astype(np.float32)) + weight = Tensor(np.array([2, 3, 1]).astype(np.float32)) + pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32)) + output = loss(predict, target, weight, pos_weight) + expected = -113.55404 + # assert scalar + assert math.isclose(output.asnumpy().tolist(), expected, abs_tol=0.00001) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_reduction_sum_testcases(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + loss = Net("sum") + predict = Tensor(np.arange(6, 12).reshape(2, 3).astype(np.float32)) + target = Tensor(np.arange(6).reshape(2, 3).astype(np.float32)) + weight = Tensor(np.array([3, 3, 4]).astype(np.float32)) + pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32)) + output = loss(predict, target, weight, pos_weight) + expected = -333.96677 + # assert scalar + assert math.isclose(output.asnumpy().tolist(), expected, abs_tol=0.00001)