|
|
|
@ -16,465 +16,196 @@
|
|
|
|
|
|
|
|
|
|
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
|
|
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
// fp64: registrations for kNumberTypeFloat64 inputs (kernel element type double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Greater,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Minimum,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Maximum,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Less, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Add, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
AbsGrad,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
RealDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorMod,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Atan2,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, double)
|
|
|
|
|
bool BroadcastOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
|
|
|
|
const std::vector<KernelTensorPtr> &outputs) {
|
|
|
|
|
kernel_name_ = base_operator->name();
|
|
|
|
|
if (inputs.empty() || outputs.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (!GetOpType()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
|
|
|
|
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
|
|
|
|
if (!is_match) {
|
|
|
|
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// fp32
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Greater,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Maximum,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Minimum,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
RealDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
AbsGrad,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
DivNoNan,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorMod,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Atan2,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
NotEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateMod,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
|
|
|
BroadcastOpGpuKernelMod, float)
|
|
|
|
|
kernel_func_ = func_list_[index].second;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// fp16: registrations for kNumberTypeFloat16 inputs (kernel element type half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Greater,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Maximum,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Minimum,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
RealDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Add, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
AbsGrad,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Div, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
DivNoNan,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorMod,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Atan2,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
NotEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateMod,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
|
|
|
BroadcastOpGpuKernelMod, half)
|
|
|
|
|
bool BroadcastOpGpuKernelMod::GetOpType() {
|
|
|
|
|
auto iter = kBroadcastCmpTypeMap.find(kernel_name_);
|
|
|
|
|
if (iter != kBroadcastCmpTypeMap.end()) {
|
|
|
|
|
op_type_ = iter->second;
|
|
|
|
|
is_comp_op_ = true;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// int32
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Add, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateMod,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, int)
|
|
|
|
|
iter = kBroadcastArithmetricTypeMap.find(kernel_name_);
|
|
|
|
|
if (iter != kBroadcastArithmetricTypeMap.end()) {
|
|
|
|
|
op_type_ = iter->second;
|
|
|
|
|
is_comp_op_ = false;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// int64
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Add, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int64_t)
|
|
|
|
|
MS_LOG(ERROR) << "For 'BroadcastGpuOp', it only support these types: " << GetValidKernelTypes()
|
|
|
|
|
<< " currently, but got " << kernel_name_;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// int8: registrations for kNumberTypeInt8 inputs (kernel element type int8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, int8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, int8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateDiv, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, int8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateMod, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, int8_t)
|
|
|
|
|
// Builds a human-readable list of the operator names found in the two op-type maps, for error messages.
//
// The result has the form
//   "Valid Compare Types: <name>, <name>, ; Valid Arithmetric Types: <name>, <name>, "
// where every name (including the last of each group) is followed by ", ". Names appear in the iteration
// order of kBroadcastCmpTypeMap and kBroadcastArithmetricTypeMap respectively.
std::string BroadcastOpGpuKernelMod::GetValidKernelTypes() {
  std::ostringstream type_names;

  type_names << "Valid Compare Types: ";
  for (const auto &cmp_entry : kBroadcastCmpTypeMap) {
    type_names << cmp_entry.first << std::string(", ");
  }

  type_names << "; Valid Arithmetric Types: ";
  for (const auto &arith_entry : kBroadcastArithmetricTypeMap) {
    type_names << arith_entry.first << std::string(", ");
  }

  return type_names.str();
}
|
|
|
|
|
|
|
|
|
|
// uint32: registrations for kNumberTypeUInt32 inputs (kernel element type uint)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Sub, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint)
|
|
|
|
|
int BroadcastOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
|
|
|
|
const std::vector<KernelTensorPtr> &outputs,
|
|
|
|
|
const std::map<uint32_t, tensor::TensorPtr> &) {
|
|
|
|
|
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// uint8
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateDiv,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
TruncateMod,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint8_t)
|
|
|
|
|
lhs_shape_ = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
|
|
|
|
|
rhs_shape_ = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
|
|
|
|
|
output_shape_ = LongVecToSizeVec(outputs.at(kIndex0)->GetShapeVector());
|
|
|
|
|
output_num_ = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<size_t>());
|
|
|
|
|
is_null_input_ = CHECK_SHAPE_NULL(lhs_shape_, kernel_name_, "input_0") ||
|
|
|
|
|
CHECK_SHAPE_NULL(rhs_shape_, kernel_name_, "input_1") ||
|
|
|
|
|
CHECK_SHAPE_NULL(output_shape_, kernel_name_, "output_0");
|
|
|
|
|
if (is_null_input_) {
|
|
|
|
|
return KRET_OK;
|
|
|
|
|
}
|
|
|
|
|
need_broadcast_ = broadcast_utils::IsBroadcast(lhs_shape_, rhs_shape_);
|
|
|
|
|
if (!broadcast_utils::AlignedBroadCastShape(MAX_DIMS, &output_shape_, &lhs_shape_, &rhs_shape_)) {
|
|
|
|
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than " << MAX_DIMS
|
|
|
|
|
<< ", but got " << lhs_shape_.size();
|
|
|
|
|
return KRET_RESIZE_FAILED;
|
|
|
|
|
}
|
|
|
|
|
return KRET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// int16: registrations for kNumberTypeInt16 inputs (kernel element type int16_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int16_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int16_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
GreaterEqual,
|
|
|
|
|
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int16_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, int16_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
|
|
|
|
BroadcastOpGpuKernelMod, int16_t)
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool BroadcastOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
const std::vector<AddressPtr> &outputs) {
|
|
|
|
|
auto lhs = GetDeviceAddress<T>(inputs, kIndex0);
|
|
|
|
|
auto rhs = GetDeviceAddress<T>(inputs, kIndex1);
|
|
|
|
|
|
|
|
|
|
// uint16
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint16_t)
|
|
|
|
|
if (is_comp_op_) {
|
|
|
|
|
bool *output = GetDeviceAddress<bool>(outputs, kIndex0);
|
|
|
|
|
if (need_broadcast_) {
|
|
|
|
|
BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output, cuda_stream_);
|
|
|
|
|
} else {
|
|
|
|
|
ElewiseCmp(output_num_, op_type_, lhs, rhs, output, cuda_stream_);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
T *output = GetDeviceAddress<T>(outputs, 0);
|
|
|
|
|
if (need_broadcast_) {
|
|
|
|
|
BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output, cuda_stream_);
|
|
|
|
|
} else {
|
|
|
|
|
ElewiseArith(output_num_, op_type_, lhs, rhs, output, cuda_stream_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// uint32: registration for kNumberTypeUInt32 inputs (kernel element type uint32_t)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint32_t)
|
|
|
|
|
std::vector<KernelAttr> BroadcastOpGpuKernelMod::GetOpSupport() {
|
|
|
|
|
std::vector<KernelAttr> support_list;
|
|
|
|
|
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
|
|
|
|
[](const std::pair<KernelAttr, BroadCastFunc> &pair) { return pair.first; });
|
|
|
|
|
return support_list;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// uint64
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
|
|
|
BroadcastOpGpuKernelMod, uint64_t)
|
|
|
|
|
// Table pairing each supported (input, input, output) type signature with the LaunchKernel instantiation that
// handles it. Both inputs always share one type; the template argument of LaunchKernel is that input type.
// The bool/bool->bool signature is listed once: the original table repeated it with the identical handler, so the
// second copy added nothing.
// NOTE(review): assumes entries are only addressed by an index obtained from this same list (e.g. via
// GetOpSupport()), never by a hard-coded position — confirm against the lookup code.
std::vector<std::pair<KernelAttr, BroadcastOpGpuKernelMod::BroadCastFunc>> BroadcastOpGpuKernelMod::func_list_ = {
  // Signatures whose output is bool.
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<bool>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
   &BroadcastOpGpuKernelMod::LaunchKernel<double>},
  // Signatures whose output type equals the input type (bool/bool->bool is already covered above).
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
   &BroadcastOpGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
   &BroadcastOpGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
   &BroadcastOpGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   &BroadcastOpGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &BroadcastOpGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   &BroadcastOpGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &BroadcastOpGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &BroadcastOpGpuKernelMod::LaunchKernel<double>}};
|
|
|
|
|
|
|
|
|
|
// bool
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, bool)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
NotEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, bool)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, bool)
|
|
|
|
|
MS_REG_GPU_KERNEL_ONE(
|
|
|
|
|
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
|
|
|
BroadcastOpGpuKernelMod, bool)
|
|
|
|
|
// bool (Mul)
|
|
|
|
|
// Mul with two bool operands and a bool result.
MS_REG_GPU_KERNEL_ONE(
  Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
  BroadcastOpGpuKernelMod, bool)
|
|
|
|
|
// Factory registrations: each line registers BroadcastOpGpuKernelMod with the NativeGpuKernelMod factory under
// one op name. The original registration order is preserved.
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Add, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Atan2, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AbsGrad, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Div, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DivNoNan, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Equal, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FloorMod, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FloorDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Greater, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GreaterEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Less, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LessEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogicalOr, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogicalAnd, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mul, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mod, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Minimum, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Maximum, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, NotEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Pow, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, RealDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Sub, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TruncateDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TruncateMod, BroadcastOpGpuKernelMod);
|
|
|
|
|
} // namespace kernel
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|