add gpu BCEWithLogitsLoss kernel

This commit is contained in:
TFBunny 2021-04-30 16:46:08 -04:00
parent d3df7ec7b5
commit 9eae68efaa
8 changed files with 594 additions and 0 deletions

View File

@ -0,0 +1,128 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh"
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
template <typename T>
__global__ void FillWithoutBroadcast(const size_t size, const T *src, T *dst) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
dst[pos] = src[pos];
}
return;
}
template <typename T>
__global__ void FillAndBroadcast(const size_t size, const size_t shape_size, const size_t *src_shape,
const size_t *dst_shape, const T *src, T *dst) {
size_t dst_index_array[MAX_LOGITS_DIMENSION];
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
size_t tmp_pos = pos;
size_t pos_size = size / dst_shape[0];
dst_index_array[0] = tmp_pos / pos_size;
for (size_t i = 1; i < shape_size; i++) {
tmp_pos -= dst_index_array[i - 1] * pos_size;
pos_size = pos_size / dst_shape[i];
dst_index_array[i] = tmp_pos / pos_size;
}
size_t src_pos = 0;
size_t src_size = 1;
for (size_t i = 0; i < shape_size; i++) {
src_size *= src_shape[i];
}
for (size_t i = 0; i < shape_size; i++) {
src_size /= src_shape[i];
size_t length_by_index = Index(dst_index_array[i], src_shape[i]) * src_size;
src_pos += length_by_index;
}
dst[pos] = src[src_pos];
}
return;
}
template <typename T>
__global__ void BCEWithLogitsLossMain(size_t size, const T *predict, const T *target, const T *shape_broadcasted,
T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
T max_value = -predict[pos];
max_value = max_value > static_cast<T>(0) ? max_value : static_cast<T>(0);
const T log_weight = (shape_broadcasted[pos] - static_cast<T>(1)) * target[pos] + static_cast<T>(1);
output[pos] = (static_cast<T>(1) - target[pos]) * predict[pos] +
log_weight * (log(exp(-max_value) + exp(-predict[pos] - max_value)) + max_value);
}
return;
}
template <>
__global__ void BCEWithLogitsLossMain(size_t size, const half *predict, const half *target,
const half *shape_broadcasted, half *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
half max_value = -predict[pos];
max_value = max_value > static_cast<half>(0) ? max_value : static_cast<half>(0);
const half log_weight = (shape_broadcasted[pos] - static_cast<half>(1)) * target[pos] + static_cast<half>(1);
output[pos] = (static_cast<half>(1) - target[pos]) * predict[pos] +
log_weight * (hlog(hexp(-max_value) + hexp(-predict[pos] - max_value)) + max_value);
}
return;
}
template <typename T>
__global__ void Mul(size_t size, const T *lhs, T *rhs) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
rhs[pos] *= lhs[pos];
}
return;
}
template <typename T>
void CalBCEWithLogitsLoss(const size_t input_size, const T *predict, const T *target, const size_t *input_shape,
const size_t shape_size, const T *weight, const size_t *weight_shape,
const bool weight_need_broadcast, const T *pos_weight, const size_t *pos_weight_shape,
const bool pos_weight_need_broadcast, T *shape_broadcasted, T *output,
cudaStream_t cuda_stream) {
if (pos_weight_need_broadcast) {
FillAndBroadcast<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(
input_size, shape_size, pos_weight_shape, input_shape, pos_weight, shape_broadcasted);
} else {
FillWithoutBroadcast<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, pos_weight,
shape_broadcasted);
}
BCEWithLogitsLossMain<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, predict, target,
shape_broadcasted, output);
if (weight_need_broadcast) {
FillAndBroadcast<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, shape_size, weight_shape,
input_shape, weight, shape_broadcasted);
} else {
FillWithoutBroadcast<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, weight,
shape_broadcasted);
}
Mul<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, shape_broadcasted, output);
return;
}
template void CalBCEWithLogitsLoss<half>(const size_t input_size, const half *predict, const half *target,
const size_t *input_shape, const size_t shape_size, const half *weight,
const size_t *weight_shape, const bool weight_need_broadcast,
const half *pos_weight, const size_t *pos_weight_shape,
const bool pos_weight_need_broadcast, half *shape_broadcasted, half *output,
cudaStream_t cuda_stream);
template void CalBCEWithLogitsLoss<float>(const size_t input_size, const float *predict, const float *target,
const size_t *input_shape, const size_t shape_size, const float *weight,
const size_t *weight_shape, const bool weight_need_broadcast,
const float *pos_weight, const size_t *pos_weight_shape,
const bool pos_weight_need_broadcast, float *shape_broadcasted, float *output,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_
#define MAX_LOGITS_DIMENSION 100
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalBCEWithLogitsLoss(const size_t input_size, const T *predict, const T *target, const size_t *input_shape,
const size_t shape_size, const T *weight, const size_t *weight_shape,
const bool weight_need_broadcast, const T *pos_weight, const size_t *pos_weight_shape,
const bool pos_weight_need_broadcast, T *shape_broadcasted, T *output,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BCEWithLogitsLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BCEWithLogitsLossKernel, float)
MS_REG_GPU_KERNEL_ONE(BCEWithLogitsLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BCEWithLogitsLossKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,167 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class BCEWithLogitsLossKernel : public GpuKernel {
public:
BCEWithLogitsLossKernel() { ResetResource(); }
~BCEWithLogitsLossKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *predict = GetDeviceAddress<T>(inputs, 0);
T *target = GetDeviceAddress<T>(inputs, 1);
T *weight = GetDeviceAddress<T>(inputs, 2);
T *pos_weight = GetDeviceAddress<T>(inputs, 3);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *weight_shape = GetDeviceAddress<size_t>(workspace, 1);
size_t *pos_weight_shape = GetDeviceAddress<size_t>(workspace, 2);
T *shape_broadcasted = GetDeviceAddress<T>(workspace, 3);
T *output = GetDeviceAddress<T>(outputs, 0);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(input_shape, &input_shape_[0], input_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(weight_shape, &weight_shape_[0], weight_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync weight_shape_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_,
cudaMemcpyAsync(pos_weight_shape, &pos_weight_shape_[0], pos_weight_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync pos_weight_shape_ failed");
CalBCEWithLogitsLoss(input_size_, predict, target, input_shape, input_shape_.size(), weight, weight_shape,
weight_need_broadcast_, pos_weight, pos_weight_shape, pos_weight_need_broadcast_,
shape_broadcasted, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but BCEWithLogitsLoss needs 4 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but BCEWithLogitsLoss has 1 output.";
return false;
}
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_size_ = 1;
if (input_shape_.size() > MAX_LOGITS_DIMENSION) {
MS_LOG(EXCEPTION) << "Input dimension is " << input_shape_.size()
<< ", but BCEWithLogitsLoss can only support up to " << MAX_LOGITS_DIMENSION << "-D.";
return false;
}
for (size_t i = 0; i < input_shape_.size(); i++) {
input_size_ *= input_shape_[i];
}
// weight shape
weight_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
weight_size_ = 1;
for (size_t i = 0; i < weight_shape_.size(); i++) {
weight_size_ *= weight_shape_[i];
}
weight_need_broadcast_ = NeedBroadcast(&weight_shape_, input_shape_);
// pos_weight shape
pos_weight_size_ = 1;
pos_weight_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
for (size_t i = 0; i < pos_weight_shape_.size(); i++) {
pos_weight_size_ *= pos_weight_shape_[i];
}
pos_weight_need_broadcast_ = NeedBroadcast(&pos_weight_shape_, input_shape_);
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 1;
weight_size_ = 1;
pos_weight_size_ = 1;
weight_need_broadcast_ = false;
pos_weight_need_broadcast_ = false;
input_shape_.clear();
weight_shape_.clear();
pos_weight_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(weight_size_ * sizeof(T));
input_size_list_.push_back(pos_weight_size_ * sizeof(T));
workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));
workspace_size_list_.push_back(weight_shape_.size() * sizeof(size_t));
workspace_size_list_.push_back(pos_weight_shape_.size() * sizeof(size_t));
// extra space for holding extra array shape of input, for broadcasted
// weight and pos_weight
workspace_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(input_size_ * sizeof(T));
}
private:
bool NeedBroadcast(std::vector<size_t> *shape, const std::vector<size_t> &result_shape) {
// result_shape is larger that shape
// and shape is able to broadcasted to result_shape
if (shape->size() != result_shape.size()) {
size_t fill_size = result_shape.size() - shape->size();
(void)shape->insert(shape->begin(), fill_size, 1);
return true;
}
for (size_t i = 0; i < result_shape.size(); i++) {
if (shape->at(i) != result_shape[i]) {
return true;
}
}
return false;
}
size_t input_size_;
size_t weight_size_;
size_t pos_weight_size_;
bool weight_need_broadcast_;
bool pos_weight_need_broadcast_;
std::vector<size_t> input_shape_;
std::vector<size_t> weight_shape_;
std::vector<size_t> pos_weight_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_

View File

@ -0,0 +1,93 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/bce_with_logits_loss_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> node_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))};
(void)node_inputs.insert(node_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = func_graph->NewCNode(node_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
auto predict_input = cnode->inputs()[1];
auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)};
AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get());
// Add reduce node
string reduction = AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction);
MS_LOG(INFO) << "Create reduce node for BCEWithLogitsLoss, reduction attr is: " << reduction;
std::vector<AnfNodePtr> reduce_inputs;
if (reduction == "sum") {
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode};
} else if (reduction == "mean") {
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode};
} else {
MS_LOG(INFO) << "Reduction is none, no optimization on current BCEWithLogitsLoss.";
return nullptr;
}
auto reduce_node = func_graph->NewCNode(reduce_inputs);
MS_EXCEPTION_IF_NULL(reduce_node);
auto type = AnfAlgo::GetOutputInferDataType(node, 0);
auto shape = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape({type}, shape, reduce_node.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node);
AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
reduce_node->set_scope(cnode->scope());
return reduce_node;
}
const BaseRef BCEWithLogitsLossFusion::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
return VectorRef({prim::kPrimBCEWithLogitsLoss, Xs});
}
const AnfNodePtr BCEWithLogitsLossFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (GetBoolAttr(cnode, kAttrVisited)) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
if (cnode->inputs().size() == 0) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr("reduction", cnode)) {
MS_LOG(INFO) << "Primitive BCEWithLogitsLoss doesn't not have reduction attr.";
return nullptr;
}
return AddReduceNode(func_graph, node);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class BCEWithLogitsLossFusion : public PatternProcessPass {
public:
explicit BCEWithLogitsLossFusion(bool multigraph = true)
: PatternProcessPass("bce_with_logits_loss_fusion", multigraph) {}
~BCEWithLogitsLossFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_

View File

@ -38,6 +38,7 @@
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "backend/optimizer/gpu/print_reduce_fusion.h"
#include "backend/optimizer/gpu/bce_with_logits_loss_fusion.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
#include "backend/optimizer/gpu/reduce_precision_fusion.h"
@ -143,6 +144,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();

View File

@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self, reduction):
super(Net, self).__init__()
self.loss = P.BCEWithLogitsLoss(reduction=reduction)
def construct(self, predict, target, weight, pos_weight):
return self.loss(predict, target, weight, pos_weight)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduction_none_testcases():
# fp32 + both modes
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
loss = Net("none")
predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
output = loss(predict, target, weight, pos_weight)
expected = np.array([[0.6111006, 0.5032824, 0.26318598],
[0.58439666, 0.55301523, -0.436814]]).astype(np.float32)
np.testing.assert_almost_equal(expected, output.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
loss = Net("none")
predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
output = loss(predict, target, weight, pos_weight)
expected = np.array([[0.6111006, 0.5032824, 0.26318598],
[0.58439666, 0.55301523, -0.436814]])
np.testing.assert_almost_equal(expected, output.asnumpy())
# fp16
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
loss = Net("none")
predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float16))
target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float16))
weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float16))
pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float16))
output = loss(predict, target, weight, pos_weight)
expected = np.array([[0.611, 0.503, 0.2627],
[0.584, 0.5527, -0.437]]).astype(np.float16)
np.testing.assert_almost_equal(expected, output.asnumpy(), decimal=3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduction_mean_testcases():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
loss = Net("mean")
predict = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
target = Tensor(np.arange(34, 40).reshape(2, 3).astype(np.float32))
weight = Tensor(np.array([2, 3, 1]).astype(np.float32))
pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32))
output = loss(predict, target, weight, pos_weight)
expected = -113.55404
# assert scalar
assert math.isclose(output.asnumpy().tolist(), expected, abs_tol=0.00001)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduction_sum_testcases():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
loss = Net("sum")
predict = Tensor(np.arange(6, 12).reshape(2, 3).astype(np.float32))
target = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
weight = Tensor(np.array([3, 3, 4]).astype(np.float32))
pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32))
output = loss(predict, target, weight, pos_weight)
expected = -333.96677
# assert scalar
assert math.isclose(output.asnumpy().tolist(), expected, abs_tol=0.00001)