forked from mindspore-Ecosystem/mindspore
!4048 Fix broadcast, scatternd, reduce ops.
Merge pull request !4048 from linqingke/new_ops
commit eb84ae4593
@@ -182,30 +182,59 @@ class ArrayReduceGpuKernel : public GpuKernel {
void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
std::vector<int> inputA;
std::vector<size_t> outputC_shape = output_shape;
ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0],
inputA[1], inputA[2], inputA[3]),
"cudnnSetTensor4dDescriptor failed");
const int split_dim = 4;

if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
inputA[0], inputA[1], inputA[2], inputA[3]),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_);
for (auto dim : input_shape) {
inputA.emplace_back(SizeToInt(dim));
}
}

if (axis_[0] == -1) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
"cudnnSetTensor4dDescriptor failed");
if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) {
all_match_ = true;
outputC_shape.resize(input_shape.size(), 1);
if (outputC_shape.size() <= split_dim) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
}

for (auto dim : inputA) {
if (dim != 1) {
return;
}
}

all_match_ = true;
return;
}

std::vector<int> outputC;
if (!keep_dims_) {
for (auto i : axis_) {
(void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
}
}
std::vector<int> outputC;
ShapeNdTo4d(outputC_shape, &outputC);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
outputC[0], outputC[1], outputC[2], outputC[3]),
"cudnnSetTensor4dDescriptor failed");

if (outputC_shape.size() <= split_dim) {
ShapeNdTo4d(outputC_shape, &outputC);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
outputC[0], outputC[1], outputC[2], outputC[3]),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
for (auto dim : outputC_shape) {
outputC.emplace_back(SizeToInt(dim));
}
}

if (inputA == outputC) {
all_match_ = true;
}
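Below is a small stand-alone illustration (not part of the commit) of the split_dim dispatch introduced here: ranks up to 4 keep the old 4-D descriptor path, while higher ranks go through the new N-D helper. It assumes ShapeNdTo4d left-pads short shapes with 1s, which matches how the 4-D path is used above; PadTo4d is a hypothetical stand-in.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for ShapeNdTo4d: left-pad a shape of rank <= 4 with 1s.
std::vector<int> PadTo4d(const std::vector<size_t> &shape) {
  std::vector<int> nchw(4, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    nchw[4 - shape.size() + i] = static_cast<int>(shape[i]);
  }
  return nchw;
}

int main() {
  const size_t split_dim = 4;
  std::vector<size_t> small = {3, 5};              // rank 2 -> 4-D descriptor path
  std::vector<size_t> large = {2, 3, 4, 4, 5, 6};  // rank 6 -> N-D descriptor path
  if (small.size() <= split_dim) {
    auto nchw = PadTo4d(small);  // {1, 1, 3, 5}
    std::printf("4d path: %d %d %d %d\n", nchw[0], nchw[1], nchw[2], nchw[3]);
  }
  if (large.size() > split_dim) {
    std::printf("nd path: rank %zu handled by CudnnSetTensorNdDescriptor\n", large.size());
  }
  return 0;
}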
@@ -69,6 +69,10 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
memcpy_flag_ = true;
}

CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemsetAsync(output, static_cast<T>(0.0), output_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemSet failed in ScatterNdGpuFwdKernel::Launch.");

const size_t input_size = input_size_ / sizeof(T);
const size_t output_size = output_size_ / sizeof(T);
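A side note on the cudaMemsetAsync call above (an observation, not part of the change): cudaMemsetAsync fills the buffer byte-by-byte from its int argument, so the static_cast<T>(0.0) only works because the value converts to 0. A hedged equivalent with an explicit byte value, using a hypothetical helper name:

#include <cstddef>
#include <cuda_runtime.h>

// Hypothetical helper: zero-fill `bytes` bytes of a device buffer on `stream`.
// cudaMemsetAsync writes the low byte of its int argument, so 0 is the portable choice.
inline cudaError_t ZeroOutputAsync(void *output, size_t bytes, cudaStream_t stream) {
  return cudaMemsetAsync(output, 0, bytes, stream);
}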
@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
@@ -107,69 +108,97 @@ __device__ __forceinline__ int Index(const int &index, const int &dim) { return

template <typename T, typename S, typename Func>
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
const int &r0, const int &r1, const int &r2, const int &r3,
const int &d0, const int &d1, const int &d2, const int &d3,
const T *input0, const T *input1, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3) % d0;
int j = pos / (d2 * d3) % d1;
int k = pos / d3 % d2;
int l = pos % d3;
const int &l4, const int &l5, const int &l6, const int &r0,
const int &r1, const int &r2, const int &r3, const int &r4,
const int &r5, const int &r6, const int &d0, const int &d1,
const int &d2, const int &d3, const int &d4, const int &d5,
const int &d6, const T *input0, const T *input1, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
int k = pos / (d3 * d4 * d5 * d6) % d2;
int l = pos / (d4 * d5 * d6) % d3;
int m = pos / (d5 * d6) % d4;
int n = pos / d6 % d5;
int o = pos % d6;

int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
r_index += Index(m, r4) * r5 * r6;
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
output[pos] = Func()(input0[l_index], input1[r_index]);
}
}

template <typename T, typename S>
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
const T *input1, S *output) {
switch (op) {
case BROADCAST_TYPE_GREATER:
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_LESS:
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MINIMUM:
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MAXIMUM:
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_POWER:
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_REALDIV:
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MUL:
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_SUB:
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ADD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_FLOORDIV:
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ABSGRAD:
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
}
}

template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
const T *input0, const T *input1, S *output, cudaStream_t stream) {
int size = d0 * d1 * d2 * d3;
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
input0, input1, output);
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
S *output, cudaStream_t stream) {
int size = 1;
for (auto d : output_shape) {
size *= d;
}
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3],
lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4],
rhs_shape[5], rhs_shape[6], output_shape[0],
output_shape[1], output_shape[2], output_shape[3],
output_shape[4], output_shape[5], output_shape[6],
op, input0, input1, output);
}

template <typename T, typename S, typename Func>
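For reference, a stand-alone host-side check (not from the commit) of the flat-index decomposition used above: a linear position into a 7-D output padded with 1s is split into the coordinates (i..o) and re-linearized against a possibly-broadcast operand shape via the same Index(idx, dim) clamp.

#include <array>
#include <cassert>
#include <cstdio>

// Same clamp as the device-side Index(): a dimension of size 1 is broadcast.
static int IndexClamp(int idx, int dim) { return dim == 1 ? 0 : idx; }

// Re-derive the operand offset for a flat output position `pos`,
// given 7-element output dims `d` and operand dims `l` (padded with 1s).
static int OperandOffset(int pos, const std::array<int, 7> &d, const std::array<int, 7> &l) {
  std::array<int, 7> coord{};
  int rem = pos;
  for (int k = 6; k >= 0; --k) {  // unravel from the innermost dimension
    coord[k] = rem % d[k];
    rem /= d[k];
  }
  int offset = 0;
  for (int k = 0; k < 7; ++k) {
    offset = offset * l[k] + IndexClamp(coord[k], l[k]);
  }
  return offset;
}

int main() {
  // lhs shape (1,1,1,1,2,1,4) broadcast against output shape (1,1,1,1,2,3,4).
  std::array<int, 7> d{1, 1, 1, 1, 2, 3, 4};
  std::array<int, 7> l{1, 1, 1, 1, 2, 1, 4};
  // pos 17 -> output coord (0,0,0,0,1,1,1) -> lhs coord (0,0,0,0,1,0,1) -> offset 5
  assert(OperandOffset(17, d, l) == 5);
  std::printf("offset ok\n");
  return 0;
}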
@@ -236,30 +265,24 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
output_addr);
}

template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const float *input0, const float *input1, float *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const half *input0, const half *input1, half *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const int *input0, const int *input1, int *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
const float *input1, bool *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
const float *input1, float *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
const half *input1, bool *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
const half *input1, half *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
const int *input1, int *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
const int *input1, bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_

#include <vector>
#include "runtime/device/gpu/cuda_common.h"

enum BroadcastOpType {
@@ -35,9 +36,9 @@ enum BroadcastOpType {
};

template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
const T *input0, const T *input1, S *output, cudaStream_t stream);
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
S *output, cudaStream_t stream);

template <typename T, typename S>
void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
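A hedged usage sketch (not from the commit) of the new vector-based entry point declared above: the caller pads every shape to seven ints with trailing 1s, because the kernel always unpacks exactly seven dims per operand. Device buffer names and the build setup are placeholders; the float/float instantiation shown earlier in this diff covers this call.

#include <vector>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"

// Elementwise add with broadcasting: d_lhs is 3x1, d_rhs and d_out are 3x4 (assumed allocated).
void LaunchAdd(const float *d_lhs, const float *d_rhs, float *d_out, cudaStream_t stream) {
  std::vector<int> lhs_shape = {3, 1, 1, 1, 1, 1, 1};
  std::vector<int> rhs_shape = {3, 4, 1, 1, 1, 1, 1};
  std::vector<int> out_shape = {3, 4, 1, 1, 1, 1, 1};
  Broadcast(lhs_shape, rhs_shape, out_shape, BROADCAST_TYPE_ADD, d_lhs, d_rhs, d_out, stream);
}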
@@ -25,10 +25,10 @@ __global__ void CheckValidKernel(const size_t size, const T *box, const T *img_m
const size_t right_y = i * 4 + 3;

S valid_flag = false;
valid_flag |= !(box[left_x] >= 0.f);
valid_flag |= !(box[left_y] >= 0.f);
valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]);
valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]);
valid_flag |= !(box[left_x] >= static_cast<T>(0.0));
valid_flag |= !(box[left_y] >= static_cast<T>(0.0));
valid_flag |= !(img_metas[1] * img_metas[2] - static_cast<T>(1.0) >= box[right_x]);
valid_flag |= !(img_metas[0] * img_metas[2] - static_cast<T>(1.0) >= box[right_y]);

valid[i] = !valid_flag;
}
@@ -43,3 +43,5 @@ void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid,

template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid,
cudaStream_t cuda_stream);
template void CheckValid(const size_t &size, const half *box, const half *img_metas, bool *valid,
cudaStream_t cuda_stream);
@@ -16,27 +16,26 @@

#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"

template <typename T>
__device__ T CoordinateMax(const T a, const T b) {
__device__ float CoordinateMax(const float a, const float b) {
return (a > b ? a : b);
}

template <typename T>
__device__ T CoordinateMin(const T a, const T b) {
__device__ float CoordinateMin(const float a, const float b) {
return (a < b ? a : b);
}

template <typename T>
__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode,
const size_t input_len_0) {
T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
T overlaps_coordinate[IOU_DIMENSION];
const T epsilon = 1e-10;
float location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
float overlaps_coordinate[IOU_DIMENSION];
const float epsilon = 1e-10;
const float offset = 1.0;

for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
for (size_t j = 0; j < IOU_DIMENSION; j++) {
location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j];
location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j];
location_coordinate[0][j] = static_cast<float>(box1[(i % input_len_0) * IOU_DIMENSION + j]);
location_coordinate[1][j] = static_cast<float>(box2[(i / input_len_0) * IOU_DIMENSION + j]);
}

overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]);
@@ -44,18 +43,18 @@ __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *io
overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]);
overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]);

T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1);
T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1);
T overlaps = overlaps_w * overlaps_h;
float overlaps_w = CoordinateMax(0.0, overlaps_coordinate[2] - overlaps_coordinate[0] + offset);
float overlaps_h = CoordinateMax(0.0, overlaps_coordinate[3] - overlaps_coordinate[1] + offset);
float overlaps = overlaps_w * overlaps_h;

T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] -
location_coordinate[0][1] + 1);
T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
location_coordinate[1][1] + 1);
float area1 = (location_coordinate[0][2] - location_coordinate[0][0] + offset) * (location_coordinate[0][3] -
location_coordinate[0][1] + offset);
float area2 = (location_coordinate[1][2] - location_coordinate[1][0] + offset) * (location_coordinate[1][3] -
location_coordinate[1][1] + offset);
if (mode == 0) {
iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon);
iou_results[i] = static_cast<T>(overlaps / (area1 + area2 - overlaps + epsilon));
} else {
iou_results[i] = overlaps / (area2 + epsilon);
iou_results[i] = static_cast<T>(overlaps / (area2 + epsilon));
}
}
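With this change the arithmetic runs in float regardless of T and only the final result is cast back. For reference, a host-side restatement of the computation (an assumption-laden sketch, not from the commit; box layout and the meaning of mode are inferred from the kernel, with the same +1 pixel offset and epsilon):

#include <algorithm>
#include <cstdio>

// Boxes are [x1, y1, x2, y2]; mode 0 = IoU, otherwise intersection over area2.
float IouHost(const float b1[4], const float b2[4], int mode) {
  const float epsilon = 1e-10f;
  const float offset = 1.0f;
  float ix1 = std::max(b1[0], b2[0]);
  float iy1 = std::max(b1[1], b2[1]);
  float ix2 = std::min(b1[2], b2[2]);
  float iy2 = std::min(b1[3], b2[3]);
  float iw = std::max(0.0f, ix2 - ix1 + offset);
  float ih = std::max(0.0f, iy2 - iy1 + offset);
  float inter = iw * ih;
  float area1 = (b1[2] - b1[0] + offset) * (b1[3] - b1[1] + offset);
  float area2 = (b2[2] - b2[0] + offset) * (b2[3] - b2[1] + offset);
  return mode == 0 ? inter / (area1 + area2 - inter + epsilon) : inter / (area2 + epsilon);
}

int main() {
  const float a[4] = {0.f, 0.f, 3.f, 3.f};
  const float b[4] = {2.f, 2.f, 5.f, 5.f};
  std::printf("iou = %f\n", IouHost(a, b, 0));  // 4 / (16 + 16 - 4) ~= 0.1429
  return 0;
}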
@@ -70,3 +69,5 @@ void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const

template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode,
const size_t &input_len_0, cudaStream_t cuda_stream);
template void IOU(const size_t &size, const half *box1, const half *box2, half *iou_results, const size_t &mode,
const size_t &input_len_0, cudaStream_t cuda_stream);
@@ -84,6 +84,40 @@ class GpuKernel : public KernelMod {
}
}

// set the tensor descriptor for cudnn/cublas
void CudnnSetTensorNdDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t descriptor,
cudnnDataType_t data_type) {
if (shape.size() < 3) {
MS_EXCEPTION(ValueError) << "cudnnSetTensorNdDescriptor don't support" << shape.size() << "D.";
}
const int nbDims = shape.size();
int *dim = new (std::nothrow) int[nbDims];
if (dim == nullptr) {
MS_LOG(EXCEPTION) << "malloc dim failed.";
}
int *stride = new (std::nothrow) int[nbDims];
if (stride == nullptr) {
MS_LOG(EXCEPTION) << "malloc stride failed.";
}

for (int i = 0; i < nbDims; i++) {
dim[i] = SizeToInt(shape[i]);
stride[i] = 1;
}

for (int i = nbDims - 2; i >= 0; i--) {
stride[i] = stride[i + 1] * SizeToInt(shape[i + 1]);
}

CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(descriptor, data_type, nbDims, dim, stride),
"cudnnSetTensorNdDescriptor failed");

delete[] dim;
dim = nullptr;
delete[] stride;
stride = nullptr;
}

// choose the suitable datatype for cudnn/cublas
inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
auto type = kCudnnDtypeMap.find(Type);
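The helper builds packed (fully contiguous) strides for the descriptor: the innermost stride is 1 and each outer stride is the product of the inner dimensions. A quick stand-alone illustration of that stride rule (a sketch, not part of the change):

#include <cstdio>
#include <vector>

// Packed strides for a contiguous tensor: stride[last] = 1,
// stride[i] = stride[i + 1] * shape[i + 1] -- the same loop as the helper above.
std::vector<int> PackedStrides(const std::vector<int> &shape) {
  std::vector<int> stride(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * shape[i + 1];
  }
  return stride;
}

int main() {
  // shape (2, 3, 4, 4, 5) -> strides (240, 80, 20, 5, 1)
  for (int s : PackedStrides({2, 3, 4, 4, 5})) std::printf("%d ", s);
  std::printf("\n");
  return 0;
}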
@@ -27,6 +27,7 @@
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
template <typename T, typename S>
class BroadcastOpGpuKernel : public GpuKernel {
public:
@@ -45,9 +46,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
S *output = GetDeviceAddress<S>(outputs, 0);

if (need_broadcast_) {
Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2],
rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs,
rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
@@ -60,10 +60,13 @@ class BroadcastOpGpuKernel : public GpuKernel {
auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
need_broadcast_ = IsBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 4) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4";
if (need_broadcast_ && shape1.size() > 7) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
}

lhs_shape_.resize(MAX_DIMS, 1);
rhs_shape_.resize(MAX_DIMS, 1);
output_shape_.resize(MAX_DIMS, 1);
for (size_t i = 0; i < shape3.size(); i++) {
output_shape_[i] = shape3[i];
output_num_ *= shape3[i];
@@ -127,9 +130,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
int input1_num_;
int input2_num_;
int output_num_;
int lhs_shape_[4] = {1, 1, 1, 1};
int rhs_shape_[4] = {1, 1, 1, 1};
int output_shape_[4] = {1, 1, 1, 1};
std::vector<int> lhs_shape_;
std::vector<int> rhs_shape_;
std::vector<int> output_shape_;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
@@ -83,12 +83,19 @@ class ActivationGpuFwdKernel : public GpuKernel {
return true;
}
std::vector<int> shape;
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
"cudnnSetActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");

const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}

InitSizeLists();
return true;
}
@@ -90,12 +90,18 @@ class ActivationGradGpuKernel : public GpuKernel {
return true;
}
std::vector<int> shape;
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
"SetActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed");

const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}

InitSizeLists();
return true;
@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, half, bool)
} // namespace kernel
} // namespace mindspore
@@ -21,5 +21,8 @@ namespace kernel {
MS_REG_GPU_KERNEL_ONE(
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
IOUGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
IOUGpuKernel, half)
} // namespace kernel
} // namespace mindspore
@@ -37,8 +37,8 @@ def test_floor_div():
y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32)
y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
x2_np = np.random.randint(1, 5, (2, 1, 1, 4, 9)).astype(np.float32)
y2_np = np.random.randint(1, 5, (2, 3, 4, 4, 9)).astype(np.float32)
x3_np = np.random.randint(1, 5, 1).astype(np.float32)
y3_np = np.random.randint(1, 5, 1).astype(np.float32)
x4_np = np.array(768).astype(np.float32)
@@ -70,7 +70,7 @@ x11 = np.random.rand(1, 1, 1, 1).astype(np.float32)
axis11 = (0, 1, 2, 3)
keep_dims11 = False

x12 = np.random.rand(2, 3, 4, 4).astype(np.float32)
x12 = np.random.rand(2, 3, 4, 4, 5, 6).astype(np.float32)
axis12 = -2
keep_dims12 = False