!35842 refactor broadcast_gpu_kernel to support minimum grad grad

Merge pull request !35842 from yangsijia/minimum_grad_grad_refactor
This commit is contained in:
i-robot 2022-06-20 07:11:04 +00:00 committed by Gitee
commit daec6916ca
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 346 additions and 612 deletions

View File

@ -1357,6 +1357,10 @@ bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs)
bool AlignedBroadCastShape(size_t align_rank, std::vector<size_t> *broadcast, std::vector<size_t> *lhs,
std::vector<size_t> *rhs) {
if (broadcast == nullptr || lhs == nullptr || rhs == nullptr) {
MS_LOG(ERROR) << "input is nullptr.";
return false;
}
size_t broadcast_rank = broadcast->size();
size_t l_rank = lhs->size();
size_t r_rank = rhs->size();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -175,7 +175,6 @@ template <typename T>
__global__ void KLDivLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x,
const T *input_y, const T *dloss, T *dx) {
T epsilon = 1e-6;
T one = static_cast<T>(1);
if (reduction == ReductionMode::kNone) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_y[i], epsilon);

View File

@ -16,465 +16,196 @@
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
#include <iostream>
namespace mindspore {
namespace kernel {
// fp64
// Each MS_REG_GPU_KERNEL_ONE entry below pairs one operator name with a KernelAttr
// (two float64 inputs, one output) and the `double` instantiation of BroadcastOpGpuKernelMod.
// Greater, Less, Equal, GreaterEqual and LessEqual declare a bool output; every other
// operator in this group declares a float64 output.
// NOTE(review): the macro is defined outside this view - presumably it registers the kernel
// with the GPU kernel factory; confirm against its definition.
MS_REG_GPU_KERNEL_ONE(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Add, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Atan2,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, double)
// Initializes the kernel from the operator and its tensors.
// Steps, in order: record the operator name in kernel_name_, reject empty input/output
// lists, resolve the operator type through GetOpType(), then match the KernelAttr built
// from the tensors against GetOpSupport() and cache the corresponding launch function.
// Returns true on success; returns false when any step fails (the empty-list and
// attr-mismatch failures log an error here, a GetOpType() failure logs inside that call).
bool BroadcastOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}
// kernel_name_ must already be set: GetOpType() looks it up in the op-type maps.
if (!GetOpType()) {
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
// fp32
// NOTE(review): the float32 registrations below sit between the attr check and the
// kernel_func_ assignment in this view. They look like removed-side lines of a rendered
// diff interleaved with the new function body - confirm against the real file.
MS_REG_GPU_KERNEL_ONE(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
DivNoNan,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Atan2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
NotEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
TruncateDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
TruncateMod,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernelMod, float)
// `index` is the position of the matched KernelAttr in the support list, which
// GetOpSupport() derives from func_list_, so it selects the matching launch function.
kernel_func_ = func_list_[index].second;
return true;
}
// fp16
// Each MS_REG_GPU_KERNEL_ONE entry below pairs one operator name with a KernelAttr
// (two float16 inputs, one output) and the `half` instantiation of BroadcastOpGpuKernelMod.
// Greater, Less, Equal, GreaterEqual, LessEqual and NotEqual declare a bool output; every
// other operator in this group declares a float16 output.
// NOTE(review): the macro is defined outside this view - presumably it registers the kernel
// with the GPU kernel factory; confirm against its definition.
MS_REG_GPU_KERNEL_ONE(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Add, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
DivNoNan,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
Atan2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
NotEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
TruncateDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
TruncateMod,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernelMod, half)
// Resolves kernel_name_ to a broadcast op type.
// The comparison-op map is searched first, then the arithmetic-op map. On a hit, op_type_
// receives the mapped value and is_comp_op_ records which map matched (true for the
// comparison map, false for the arithmetic map). Returns true on a hit; otherwise logs the
// list of valid names from GetValidKernelTypes() and returns false.
bool BroadcastOpGpuKernelMod::GetOpType() {
auto iter = kBroadcastCmpTypeMap.find(kernel_name_);
if (iter != kBroadcastCmpTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = true;
return true;
}
// int32
// NOTE(review): the int32 registrations below appear inside the function body in this
// view. They look like removed-side lines of a rendered diff interleaved with the new
// code - confirm against the real file.
MS_REG_GPU_KERNEL_ONE(
Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
// NOTE(review): this entry declares bool inputs and a bool output but uses the `int`
// instantiation - verify that pairing is intended.
MS_REG_GPU_KERNEL_ONE(
Add, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
TruncateDiv,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
TruncateMod,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernelMod, int)
// Not a comparison op: fall back to the arithmetic-op map, reusing the same iterator.
iter = kBroadcastArithmetricTypeMap.find(kernel_name_);
if (iter != kBroadcastArithmetricTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = false;
return true;
}
// int64
// NOTE(review): as with the int32 group above, these int64 registrations look like
// interleaved removed-side diff lines - confirm against the real file.
MS_REG_GPU_KERNEL_ONE(
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Add, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int64_t)
// Name found in neither map: report every valid name alongside the offending one.
MS_LOG(ERROR) << "For 'BroadcastGpuOp', it only support these types: " << GetValidKernelTypes()
<< " currently, but got " << kernel_name_;
return false;
}
// int8
// Each MS_REG_GPU_KERNEL_ONE entry below pairs one operator name with a KernelAttr
// (two int8 inputs, one output) and the `int8_t` instantiation of BroadcastOpGpuKernelMod.
// Equal, GreaterEqual and LessEqual declare a bool output; DivNoNan, Mul, TruncateDiv and
// TruncateMod declare an int8 output.
// NOTE(review): the macro is defined outside this view - presumably it registers the kernel
// with the GPU kernel factory; confirm against its definition.
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
BroadcastOpGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
BroadcastOpGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(
TruncateDiv, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
BroadcastOpGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(
TruncateMod, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
BroadcastOpGpuKernelMod, int8_t)
// Builds a human-readable list of every operator name this kernel accepts, for use in
// error messages. The comparison-op names come first, then the arithmetic-op names; each
// name is followed by ", " (including the last one in each group).
std::string BroadcastOpGpuKernelMod::GetValidKernelTypes() {
  std::ostringstream message;

  message << "Valid Compare Types: ";
  for (const auto &name_and_type : kBroadcastCmpTypeMap) {
    message << name_and_type.first << ", ";
  }

  message << "; Valid Arithmetric Types: ";
  for (const auto &name_and_type : kBroadcastArithmetricTypeMap) {
    message << name_and_type.first << ", ";
  }

  return message.str();
}
// uint32
// Sub and Mul for two uint32 inputs and a uint32 output, bound to the `uint` instantiation
// of BroadcastOpGpuKernelMod.
// NOTE(review): Mul with the same UInt32 attrs is registered again further down in this
// file with `uint32_t` as the template type - check whether that is a duplicate.
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
BroadcastOpGpuKernelMod, uint)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
BroadcastOpGpuKernelMod, uint)
// Recomputes the shape-dependent state of the kernel.
// After delegating to KernelMod::Resize (whose non-OK result is returned unchanged), this
// reads the shapes of input 0, input 1 and output 0, computes the number of output
// elements, and decides whether broadcasting is needed.
// Returns KRET_OK on success or when CHECK_SHAPE_NULL flags any of the shapes (in which
// case only is_null_input_ and the raw shapes are updated), and KRET_RESIZE_FAILED when
// AlignedBroadCastShape rejects the shapes.
int BroadcastOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
// uint8
// NOTE(review): the uint8 registrations below appear inside the function body in this
// view. They look like removed-side lines of a rendered diff interleaved with the new
// code - confirm against the real file.
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
BroadcastOpGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
BroadcastOpGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
TruncateDiv,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
BroadcastOpGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
TruncateMod,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
BroadcastOpGpuKernelMod, uint8_t)
lhs_shape_ = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
rhs_shape_ = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
output_shape_ = LongVecToSizeVec(outputs.at(kIndex0)->GetShapeVector());
// The seed must be a size_t: std::accumulate keeps its running value in the type of the
// seed, so an `int` literal would narrow the product to int after every multiplication
// and give a wrong element count for large outputs.
output_num_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t{1}, std::multiplies<size_t>());
is_null_input_ = CHECK_SHAPE_NULL(lhs_shape_, kernel_name_, "input_0") ||
CHECK_SHAPE_NULL(rhs_shape_, kernel_name_, "input_1") ||
CHECK_SHAPE_NULL(output_shape_, kernel_name_, "output_0");
if (is_null_input_) {
return KRET_OK;
}
// IsBroadcast is evaluated before AlignedBroadCastShape, which receives mutable pointers
// to the three shapes (presumably to pad them up to MAX_DIMS - confirm in its definition).
need_broadcast_ = broadcast_utils::IsBroadcast(lhs_shape_, rhs_shape_);
if (!broadcast_utils::AlignedBroadCastShape(MAX_DIMS, &output_shape_, &lhs_shape_, &rhs_shape_)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than " << MAX_DIMS
<< ", but got " << lhs_shape_.size();
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
// int16
// Each MS_REG_GPU_KERNEL_ONE entry below pairs one operator name with a KernelAttr
// (two int16 inputs, one output) and the `int16_t` instantiation of BroadcastOpGpuKernelMod.
// Equal, NotEqual, GreaterEqual and LessEqual declare a bool output; Mul declares an int16
// output.
// NOTE(review): the macro is defined outside this view - presumably it registers the kernel
// with the GPU kernel factory; confirm against its definition.
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
BroadcastOpGpuKernelMod, int16_t)
// Launches the binary op selected by op_type_ on the two inputs.
// T is the element type used to read both inputs. Comparison ops (is_comp_op_ true) write
// a bool output through BroadcastCmp / ElewiseCmp; all other ops write a T output through
// BroadcastArith / ElewiseArith. The Broadcast* variant is used when need_broadcast_ is
// set, otherwise the element-wise variant runs over output_num_ elements.
// All work is issued on cuda_stream_. Always returns true.
template <typename T>
bool BroadcastOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto lhs = GetDeviceAddress<T>(inputs, kIndex0);
auto rhs = GetDeviceAddress<T>(inputs, kIndex1);
// uint16
// NOTE(review): this registration appears inside the function body in this view. It looks
// like a removed-side line of a rendered diff interleaved with the new code - confirm
// against the real file.
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
BroadcastOpGpuKernelMod, uint16_t)
if (is_comp_op_) {
// Comparison result is bool regardless of the input element type T.
bool *output = GetDeviceAddress<bool>(outputs, kIndex0);
if (need_broadcast_) {
BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output, cuda_stream_);
} else {
ElewiseCmp(output_num_, op_type_, lhs, rhs, output, cuda_stream_);
}
} else {
// Arithmetic result has the same element type as the inputs.
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output, cuda_stream_);
} else {
ElewiseArith(output_num_, op_type_, lhs, rhs, output, cuda_stream_);
}
}
return true;
}
// uint32
// Mul for two uint32 inputs and a uint32 output, bound to the `uint32_t` instantiation of
// BroadcastOpGpuKernelMod.
// NOTE(review): Mul with the same UInt32 attrs is also registered earlier in this file with
// `uint` as the template type - check whether that is a duplicate.
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
BroadcastOpGpuKernelMod, uint32_t)
std::vector<KernelAttr> BroadcastOpGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BroadCastFunc> &pair) { return pair.first; });
return support_list;
}
// uint64
// Mul for two uint64 inputs and a uint64 output, bound to the `uint64_t` instantiation of
// BroadcastOpGpuKernelMod.
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
BroadcastOpGpuKernelMod, uint64_t)
// Table of supported type signatures and the LaunchKernel instantiation that serves each.
// GetOpSupport() exposes the KernelAttr column and Init() uses the matched index to pick
// the function column, so both columns must stay in the same order.
//
// The table has two groups:
//   1. inputs of type X with a bool output (the comparison signatures);
//   2. inputs of type X with an output of type X (the arithmetic signatures).
// For X = bool the two signatures are identical, so that entry is listed once, in group 1;
// a second copy in group 2 would add nothing because it maps to the same LaunchKernel<bool>.
std::vector<std::pair<KernelAttr, BroadcastOpGpuKernelMod::BroadCastFunc>> BroadcastOpGpuKernelMod::func_list_ = {
  // Group 1: (X, X) -> bool
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<bool>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<double>},
  // Group 2: (X, X) -> X
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
   &BroadcastOpGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
   &BroadcastOpGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   &BroadcastOpGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &BroadcastOpGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   &BroadcastOpGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &BroadcastOpGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &BroadcastOpGpuKernelMod::LaunchKernel<double>}};
// bool
// Each MS_REG_GPU_KERNEL_ONE entry below pairs one operator name with a KernelAttr of two
// bool inputs and a bool output, bound to the `bool` instantiation of
// BroadcastOpGpuKernelMod.
// NOTE(review): the macro is defined outside this view - presumably it registers the kernel
// with the GPU kernel factory; confirm against its definition.
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, bool)
MS_REG_GPU_KERNEL_ONE(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, bool)
MS_REG_GPU_KERNEL_ONE(
LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, bool)
MS_REG_GPU_KERNEL_ONE(
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, bool)
// bool (this entry was labelled "uint64", but its attrs and template type are bool)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, bool)
// Factory registrations: one entry per op name, all handled by the single BroadcastOpGpuKernelMod
// class and registered against the NativeGpuKernelMod base.
// Unlike the per-dtype registration macro, these entries carry no dtype signature; presumably the
// dtype-specific LaunchKernel<T> is chosen from the (KernelAttr, function) table defined above --
// TODO confirm against Init().
// NOTE(review): every name listed here is expected to have a matching key in kBroadcastCmpTypeMap or
// kBroadcastArithmetricTypeMap (declared in the kernel header); keep the two in sync when adding ops.
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Add, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Atan2, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AbsGrad, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Div, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DivNoNan, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Equal, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FloorMod, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FloorDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Greater, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GreaterEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Less, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LessEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogicalOr, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogicalAnd, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mul, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mod, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Minimum, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Maximum, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, NotEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Pow, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, RealDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Sub, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TruncateDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TruncateMod, BroadcastOpGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -21,6 +21,10 @@
#include <vector>
#include <string>
#include <map>
#include <functional>
#include <utility>
#include <algorithm>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
@ -31,195 +35,83 @@
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
template <typename T>
class BroadcastOpGpuKernelMod : public DeprecatedNativeGpuKernelMod {
// Lookup table from op name to BroadcastOpType for the comparison and logical ops
// (Greater/Less/Equal/..., LogicalAnd/LogicalOr).
// Presumably ops found here produce a bool output regardless of the input dtype -- TODO confirm
// against the kernel's launch path.
static const std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER},
{"Less", BROADCAST_TYPE_LESS},
{"Equal", BROADCAST_TYPE_EQUAL},
{"GreaterEqual", BROADCAST_TYPE_GREATER_EQUAL},
{"LessEqual", BROADCAST_TYPE_LESS_EQUAL},
{"NotEqual", BROADCAST_TYPE_NOT_EQUAL},
{"LogicalAnd", BROADCAST_TYPE_LOGICAL_AND},
{"LogicalOr", BROADCAST_TYPE_LOGICAL_OR},
};
// Lookup table from op name to BroadcastOpType for the arithmetic ops
// (Maximum, Minimum, Pow, the division/modulo family, Add/Sub/Mul, AbsGrad, Atan2).
// Note that the "Pow" op name maps to the BROADCAST_TYPE_POWER enumerator.
static const std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
{"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM},
{"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL},
{"Sub", BROADCAST_TYPE_SUB},
{"Add", BROADCAST_TYPE_ADD},
{"FloorDiv", BROADCAST_TYPE_FLOORDIV},
{"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV},
{"DivNoNan", BROADCAST_TYPE_DIVNONAN},
{"Mod", BROADCAST_TYPE_MOD},
{"FloorMod", BROADCAST_TYPE_FLOORMOD},
{"Atan2", BROADCAST_TYPE_ATAN2},
{"TruncateDiv", BROADCAST_TYPE_TRUNCATEDIV},
{"TruncateMod", BROADCAST_TYPE_TRUNCATEMOD},
};
class BroadcastOpGpuKernelMod : public NativeGpuKernelMod {
public:
BroadcastOpGpuKernelMod() { ResetResource(); }
BroadcastOpGpuKernelMod() {}
~BroadcastOpGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
T *lhs = GetDeviceAddress<T>(inputs, 0);
T *rhs = GetDeviceAddress<T>(inputs, 1);
if (is_comp_op_) {
bool *output = GetDeviceAddress<bool>(outputs, 0);
if (need_broadcast_) {
BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseCmp(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
} else {
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
return true;
cuda_stream_ = reinterpret_cast<cudaStream_t>(cuda_stream);
return kernel_func_(this, inputs, outputs);
}
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
GetOpType(kernel_node);
auto shape1 = Convert2SizeTClipNeg(AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0));
auto shape2 = Convert2SizeTClipNeg(AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1));
auto shape3 = Convert2SizeTClipNeg(AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0));
is_null_input_ = CHECK_SHAPE_NULL(shape1, kernel_name_, "input") ||
CHECK_SHAPE_NULL(shape2, kernel_name_, "input") ||
CHECK_SHAPE_NULL(shape3, kernel_name_, "output");
if (is_null_input_) {
InitSizeLists();
return true;
}
need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << MAX_DIMS
<< ", but got " << shape1.size();
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
lhs_shape_.resize(MAX_DIMS, 1);
rhs_shape_.resize(MAX_DIMS, 1);
output_shape_.resize(MAX_DIMS, 1);
for (size_t i = 0; i < shape3.size(); i++) {
if (need_broadcast_) {
if (i < MAX_DIMS) {
output_shape_[i] = shape3[i];
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of output should be less than " << MAX_DIMS
<< ", but got " << i;
}
}
output_num_ *= shape3[i];
}
int lhs_offset = shape3.size() - shape1.size();
for (size_t j = 0; j < shape1.size(); j++) {
if (need_broadcast_) {
if ((j + lhs_offset) >= 0 && (j + lhs_offset) < MAX_DIMS) {
lhs_shape_[j + lhs_offset] = shape1[j];
} else {
auto index = j + lhs_offset;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of input cannot be " << index << ", but got "
<< index;
}
}
input1_num_ *= shape1[j];
}
int rhs_offset = shape3.size() - shape2.size();
for (size_t k = 0; k < shape2.size(); k++) {
if (need_broadcast_) {
if ((k + rhs_offset) >= 0 && (k + rhs_offset) < MAX_DIMS) {
rhs_shape_[k + rhs_offset] = shape2[k];
} else {
auto index = k + rhs_offset;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of input cannot be " << index << ", but got "
<< index;
}
}
input2_num_ *= shape2[k];
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
  // Forget the cached operand/output shapes and the buffer-size lists.
  lhs_shape_.clear();
  rhs_shape_.clear();
  output_shape_.clear();
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();

  // Element counts restart at 1 so they can be accumulated by multiplication.
  input1_num_ = 1;
  input2_num_ = 1;
  output_num_ = 1;

  // Return the op selection and mode flags to their initial state.
  op_type_ = BROADCAST_TYPE_INVALID;
  need_broadcast_ = false;
  is_comp_op_ = false;
  is_null_input_ = false;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void InitResource() override { return; }
void InitSizeLists() override {
  // Both inputs hold elements of type T.
  const size_t input_elem_bytes = sizeof(T);
  input_size_list_.push_back(input1_num_ * input_elem_bytes);
  input_size_list_.push_back(input2_num_ * input_elem_bytes);

  // Comparison ops write bool results; every other op writes T.
  size_t output_elem_bytes = input_elem_bytes;
  if (is_comp_op_) {
    output_elem_bytes = sizeof(bool);
  }
  output_size_list_.push_back(output_num_ * output_elem_bytes);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
void GetOpType(const CNodePtr &kernel_node) {
std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
bool GetOpType();
static const std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER},
{"Less", BROADCAST_TYPE_LESS},
{"Equal", BROADCAST_TYPE_EQUAL},
{"GreaterEqual", BROADCAST_TYPE_GREATER_EQUAL},
{"LessEqual", BROADCAST_TYPE_LESS_EQUAL},
{"NotEqual", BROADCAST_TYPE_NOT_EQUAL},
{"LogicalAnd", BROADCAST_TYPE_LOGICAL_AND},
{"LogicalOr", BROADCAST_TYPE_LOGICAL_OR},
};
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
auto iter = kBroadcastCmpTypeMap.find(kernel_name);
if (iter != kBroadcastCmpTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = true;
return;
}
using BroadCastFunc = std::function<bool(BroadcastOpGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static const std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
{"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM},
{"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL},
{"Sub", BROADCAST_TYPE_SUB},
{"Add", BROADCAST_TYPE_ADD},
{"FloorDiv", BROADCAST_TYPE_FLOORDIV},
{"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV},
{"DivNoNan", BROADCAST_TYPE_DIVNONAN},
{"Mod", BROADCAST_TYPE_MOD},
{"FloorMod", BROADCAST_TYPE_FLOORMOD},
{"Atan2", BROADCAST_TYPE_ATAN2},
{"TruncateDiv", BROADCAST_TYPE_TRUNCATEDIV},
{"TruncateMod", BROADCAST_TYPE_TRUNCATEMOD},
};
iter = kBroadcastArithmetricTypeMap.find(kernel_name);
if (iter != kBroadcastArithmetricTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = false;
return;
}
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< ", only support these types: Maximum, Minimum, Pow, RealDiv, Mul, Sub, Add, Div, DivNoNan, "
"Mod, FloorDiv, AbsGrad, FloorMod, Atan2, TruncateDiv or TruncateMod currently, but got "
<< kernel_name;
}
std::string GetValidKernelTypes();
BroadcastOpType op_type_;
bool need_broadcast_;
bool is_comp_op_;
bool is_null_input_;
size_t input1_num_;
size_t input2_num_;
size_t output_num_;
std::vector<size_t> lhs_shape_;
std::vector<size_t> rhs_shape_;
std::vector<size_t> output_shape_;
size_t unit_size_{1};
size_t output_num_{1};
cudaStream_t cuda_stream_{nullptr};
BroadCastFunc kernel_func_{};
static std::vector<std::pair<KernelAttr, BroadCastFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,65 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/atan2.h"
#include <string>
#include <algorithm>
#include <memory>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of Atan2 by broadcasting the shapes of its two inputs.
// Raises (through the check helpers) when the primitive is null, when the number of inputs is not
// exactly two, or when any input abstract is null.
abstract::ShapePtr Atan2InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto op_name = primitive->name();
  constexpr int64_t kExpectedInputNum = 2;
  const int64_t actual_input_num = SizeToLong(input_args.size());
  (void)CheckAndConvertUtils::CheckInteger("input number", actual_input_num, kEqual, kExpectedInputNum, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  return BroadCastInferShape(op_name, input_args);
}
// Infers the output dtype of Atan2.
// Both inputs ("x" and "y") must be tensors of the same dtype, drawn from common_valid_types; that
// shared dtype is returned.
// Raises (through the check helpers) when the primitive or any input abstract is null, when fewer
// than two inputs are supplied, or when the dtypes are invalid or differ.
TypePtr Atan2InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  // Guard the primitive before it is dereferenced below, matching Atan2InferShape.
  MS_EXCEPTION_IF_NULL(prim);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const auto op_name = prim->name();
  const int64_t input_num = 2;
  // The size check must precede the indexed accesses to input_args[0] and input_args[1].
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
                                           op_name);
  std::map<std::string, TypePtr> types;
  (void)types.emplace("x", input_args[0]->BuildType());
  (void)types.emplace("y", input_args[1]->BuildType());
  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(Atan2, BaseOperator);
// Entry point for Atan2 abstract inference: combines the inferred dtype and the inferred
// (broadcast) shape into a single abstract value. The dtype is inferred before the shape.
AbstractBasePtr Atan2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
  const auto inferred_type = Atan2InferType(primitive, input_args);
  const auto inferred_shape = Atan2InferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
REGISTER_PRIMITIVE_C(kNameAtan2, Atan2);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ATAN2_H_
#define MINDSPORE_CORE_OPS_ATAN2_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAtan2 = "Atan2";
/// \brief Calculate the angle in the Euclidean plane.
///
/// The operator declares two inputs named "x" and "y" and a single output named "output".
class MIND_API Atan2 : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Atan2);
  /// \brief Constructor. Creates the operator under the default name kNameAtan2.
  Atan2() : BaseOperator(kNameAtan2) { InitIOName({"x", "y"}, {"output"}); }
  /// \brief Constructor. Creates the operator under a caller-supplied name.
  ///
  /// \param[in] k_name Name passed to BaseOperator; taken by const reference to avoid a copy.
  explicit Atan2(const std::string &k_name) : BaseOperator(k_name) { InitIOName({"x", "y"}, {"output"}); }
  /// \brief Init. Atan2 has no attributes to initialize, so this is a no-op.
  void Init() const {}
};
/// \brief Infer the abstract value of the Atan2 output from its inputs.
///
/// \param[in] primitive The Atan2 primitive.
/// \param[in] input_args Abstracts of the operator inputs.
///
/// \return The abstract of the output.
abstract::AbstractBasePtr Atan2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ATAN2_H_