forked from mindspore-Ecosystem/mindspore
!4298 Execute MaximumGrad on demand
Merge pull request !4298 from chenweifeng/maximumgrad
commit 6e3c87be46
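This merge makes the MaximumGrad/MinimumGrad GPU kernels compute gradients on demand: two flags (grad_x1/grad_x2 in the CUDA code, grad_x_/grad_y_ in the kernel wrapper) are threaded through the whole call chain so that dy is accumulated into dx1 or dx2 only when that gradient is actually requested. A minimal self-contained CUDA sketch of the gating pattern follows; it is not the MindSpore file itself, and the names (MaximumGradFuncDemo, DemoKernel) and test data are illustrative only.

#include <cstdio>
#include <cuda_runtime.h>

// Gated gradient functor: dy is routed to dx1 where x1 > x2 and to dx2 otherwise,
// but each side is written only if its flag says that gradient is needed.
struct MaximumGradFuncDemo {
  __device__ __forceinline__ void operator()(const float &x1, const float &x2, const bool &grad_x1,
                                             const bool &grad_x2, const float &dy, float *dx1, float *dx2) {
    if (grad_x1 && x1 > x2) {
      atomicAdd(dx1, dy);
    } else if (grad_x2 && x1 <= x2) {
      atomicAdd(dx2, dy);
    }
  }
};

__global__ void DemoKernel(int n, bool grad_x1, bool grad_x2, const float *x1, const float *x2, const float *dy,
                           float *dx1, float *dx2) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n; pos += blockDim.x * gridDim.x) {
    MaximumGradFuncDemo()(x1[pos], x2[pos], grad_x1, grad_x2, dy[pos], dx1 + pos, dx2 + pos);
  }
}

int main() {
  const int n = 4;
  float h_x1[n] = {1.f, 5.f, 2.f, 7.f}, h_x2[n] = {3.f, 4.f, 2.f, 1.f}, h_dy[n] = {1.f, 1.f, 1.f, 1.f};
  float *x1, *x2, *dy, *dx1, *dx2;
  cudaMalloc(&x1, n * sizeof(float));
  cudaMalloc(&x2, n * sizeof(float));
  cudaMalloc(&dy, n * sizeof(float));
  cudaMalloc(&dx1, n * sizeof(float));
  cudaMalloc(&dx2, n * sizeof(float));
  cudaMemcpy(x1, h_x1, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(x2, h_x2, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, h_dy, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(dx1, 0, n * sizeof(float));
  cudaMemset(dx2, 0, n * sizeof(float));
  DemoKernel<<<1, 128>>>(n, /*grad_x1=*/true, /*grad_x2=*/false, x1, x2, dy, dx1, dx2);  // only dx1 is requested
  float h_dx1[n], h_dx2[n];
  cudaMemcpy(h_dx1, dx1, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_dx2, dx2, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("dx1[%d]=%g  dx2[%d]=%g\n", i, h_dx1[i], i, h_dx2[i]);  // dx2 stays all zero
  cudaFree(x1); cudaFree(x2); cudaFree(dy); cudaFree(dx1); cudaFree(dx2);
  return 0;
}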
@@ -19,10 +19,11 @@
 template <typename T>
 struct MinimumGradFunc {
-  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
-    if (x1 < x2) {
+  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
+                                             const T &dy, T *dx1, T *dx2) {
+    if (grad_x1 && x1 < x2) {
       atomicAdd(dx1, dy);
-    } else {
+    } else if (grad_x2 && x1 >= x2) {
       atomicAdd(dx2, dy);
     }
   }
 };
@@ -30,10 +31,11 @@ struct MinimumGradFunc {
 template <typename T>
 struct MaximumGradFunc {
-  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
-    if (x1 > x2) {
+  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
+                                             const T &dy, T *dx1, T *dx2) {
+    if (grad_x1 && x1 > x2) {
       atomicAdd(dx1, dy);
-    } else {
+    } else if (grad_x2 && x1 <= x2) {
       atomicAdd(dx2, dy);
     }
   }
 };
@@ -45,7 +47,8 @@ template <typename T, typename Func>
 __device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3,
                                                        const int &r0, const int &r1, const int &r2, const int &r3,
                                                        const int &d0, const int &d1, const int &d2, const int &d3,
-                                                       const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) {
+                                                       const bool &grad_x1, const bool &grad_x2, const T *x1,
+                                                       const T *x2, const T *dy, T *dx1, T *dx2) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
     int i = pos / (d1 * d2 * d3) % d0;
     int j = pos / (d2 * d3) % d1;
@@ -54,69 +57,71 @@ __device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &
     int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
     int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
-    Func()(x1[l_index], x2[r_index], dy[pos], dx1 + l_index, dx2 + r_index);
+    Func()(x1[l_index], x2[r_index], grad_x1, grad_x2, dy[pos], dx1 + l_index, dx2 + r_index);
   }
 }
 
 template <typename T>
 __global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
                                     const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
-                                    enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
-                                    T *dx2) {
+                                    const bool grad_x1, const bool grad_x2, enum BroadcastGradOpType op, const T *x1,
+                                    const T *x2, const T *dy, T *dx1, T *dx2) {
   switch (op) {
     case BROADCAST_GRAD_TYPE_MINIMUM:
-      return BroadcastGradOperator<T, MinimumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
-                                                          dx1, dx2);
+      return BroadcastGradOperator<T, MinimumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, grad_x1,
+                                                          grad_x2, x1, x2, dy, dx1, dx2);
     case BROADCAST_GRAD_TYPE_MAXIMUM:
-      return BroadcastGradOperator<T, MaximumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
-                                                          dx1, dx2);
+      return BroadcastGradOperator<T, MaximumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, grad_x1,
+                                                          grad_x2, x1, x2, dy, dx1, dx2);
   }
 }
 
 template <typename T>
 void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                    const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
-                   cudaStream_t stream) {
+                   const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const T *x1, const T *x2,
+                   const T *dy, T *dx1, T *dx2, cudaStream_t stream) {
   int size = d0 * d1 * d2 * d3;
-  BroadcastGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
-                                                                    x1, x2, dy, dx1, dx2);
+  BroadcastGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3,
+                                                                    grad_x1, grad_x2, op, x1, x2, dy, dx1, dx2);
 }
 
 template <typename T, typename Func>
-__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *x1, const T *x2, const T *dy, T *dx1,
-                                                    T *dx2) {
+__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const bool &grad_x1, const bool &grad_x2,
+                                                    const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
-    Func()(x1[pos], x2[pos], dy[pos], dx1 + pos, dx2 + pos);
+    Func()(x1[pos], x2[pos], grad_x1, grad_x2, dy[pos], dx1 + pos, dx2 + pos);
   }
 }
 
 template <typename T>
-__global__ void NoBroadcastGradKernel(const int nums, enum BroadcastGradOpType op, const T *x1, const T *x2,
-                                      const T *dy, T *dx1, T *dx2) {
+__global__ void NoBroadcastGradKernel(const int nums, const bool grad_x1, const bool grad_x2,
+                                      enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
+                                      T *dx2) {
   switch (op) {
     case BROADCAST_GRAD_TYPE_MINIMUM:
-      return NoBroadcastOperator<T, MinimumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
+      return NoBroadcastOperator<T, MinimumGradFunc<T>>(nums, grad_x1, grad_x2, x1, x2, dy, dx1, dx2);
     case BROADCAST_GRAD_TYPE_MAXIMUM:
-      return NoBroadcastOperator<T, MaximumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
+      return NoBroadcastOperator<T, MaximumGradFunc<T>>(nums, grad_x1, grad_x2, x1, x2, dy, dx1, dx2);
   }
 }
 
 template <typename T>
-void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
-                     T *dx2, cudaStream_t stream) {
-  NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, x1, x2, dy, dx1, dx2);
+void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                     const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, cudaStream_t stream) {
+  NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, grad_x1, grad_x2, op, x1, x2, dy, dx1, dx2);
 }
 
-template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2,
-                              const float *dy, float *dx1, float *dx2, cudaStream_t stream);
-template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const int *x1, const int *x2,
-                              const int *dy, int *dx1, int *dx2, cudaStream_t stream);
+template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                              const float *x1, const float *x2, const float *dy, float *dx1, float *dx2,
+                              cudaStream_t stream);
+template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                              const int *x1, const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
 template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                            enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1,
-                            float *dx2, cudaStream_t stream);
+                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
+                            const float *x2, const float *dy, float *dx1, float *dx2, cudaStream_t stream);
 template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                            enum BroadcastGradOpType op, const int *x1, const int *x2, const int *dy, int *dx1,
-                            int *dx2, cudaStream_t stream);
+                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int *x1,
+                            const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
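In the broadcast path above, Index(i, dim) maps an output coordinate back into each input and, following the usual broadcast rule, presumably returns 0 whenever that input's dimension is 1; many output positions then land on the same dx element, which is why the functors accumulate with atomicAdd instead of plain stores. A standalone sketch of that reduction for a 1-D input broadcast from a single element (names and sizes are illustrative, not from the sources):

#include <cstdio>
#include <cuda_runtime.h>

// When the input dimension is 1, every output position maps back to element 0,
// so the gradient must be summed atomically across all threads.
__device__ __forceinline__ int IndexDemo(int i, int dim) { return dim == 1 ? 0 : i; }

__global__ void ReduceGradDemo(int n, int x2_dim, const float *dy, float *dx2) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n; pos += blockDim.x * gridDim.x) {
    atomicAdd(dx2 + IndexDemo(pos, x2_dim), dy[pos]);  // all threads hit dx2[0] when x2_dim == 1
  }
}

int main() {
  const int n = 1024;
  float ones[n];
  for (int i = 0; i < n; ++i) ones[i] = 1.0f;
  float *dy, *dx2;
  cudaMalloc(&dy, n * sizeof(float));
  cudaMalloc(&dx2, sizeof(float));
  cudaMemcpy(dy, ones, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(dx2, 0, sizeof(float));
  ReduceGradDemo<<<4, 256>>>(n, /*x2_dim=*/1, dy, dx2);
  float result = 0.0f;
  cudaMemcpy(&result, dx2, sizeof(float), cudaMemcpyDeviceToHost);
  printf("dy reduced back onto the broadcast x2: %g\n", result);  // expected 1024
  cudaFree(dy);
  cudaFree(dx2);
  return 0;
}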
@@ -28,11 +28,11 @@ enum BroadcastGradOpType {
 template <typename T>
 void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                    const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
-                   cudaStream_t stream);
+                   const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const T *x1, const T *x2,
+                   const T *dy, T *dx1, T *dx2, cudaStream_t stream);
 
 template <typename T>
-void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
-                     T *dx2, cudaStream_t stream);
+void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                     const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, cudaStream_t stream);
 
 #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
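The header above only declares the two entry points; a minimal host-side sketch of calling the updated NoBroadcastGrad with the new flags is shown below. The include path is a guess inferred from the header guard, and the wrapper name and setup are illustrative, not part of the sources:

#include <cuda_runtime.h>
#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh"  // hypothetical path, inferred from the include guard

// dx1/dx2 are assumed to be zero-filled by the caller, as the GPU kernel wrapper
// does with cudaMemset before launching.
void LaunchMaximumGradDemo(const float *x1, const float *x2, const float *dy, float *dx1, float *dx2, int nums,
                           bool grad_x, bool grad_y, cudaStream_t stream) {
  NoBroadcastGrad(nums, grad_x, grad_y, BROADCAST_GRAD_TYPE_MAXIMUM, x1, x2, dy, dx1, dx2, stream);
}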
@@ -31,7 +31,13 @@ template <typename T>
 class BroadcastOpGradGpuKernel : public GpuKernel {
  public:
   BroadcastOpGradGpuKernel()
-      : op_type_(BROADCAST_GRAD_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
+      : op_type_(BROADCAST_GRAD_TYPE_INVALID),
+        need_broadcast_(false),
+        input1_num_(1),
+        input2_num_(1),
+        output_num_(1),
+        grad_x_(false),
+        grad_y_(false) {}
   ~BroadcastOpGradGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -52,10 +58,11 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
                                "cudaMemSet Failed");
     if (need_broadcast_) {
       BroadcastGrad(x1_shape_[0], x1_shape_[1], x1_shape_[2], x1_shape_[3], x2_shape_[0], x2_shape_[1], x2_shape_[2],
-                    x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], op_type_, x1, x2, dy, dx1,
-                    dx2, reinterpret_cast<cudaStream_t>(stream_ptr));
+                    x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], grad_x_, grad_y_, op_type_,
+                    x1, x2, dy, dx1, dx2, reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
-      NoBroadcastGrad(output_num_, op_type_, x1, x2, dy, dx1, dx2, reinterpret_cast<cudaStream_t>(stream_ptr));
+      NoBroadcastGrad(output_num_, grad_x_, grad_y_, op_type_, x1, x2, dy, dx1, dx2,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
     }
 
     return true;
@@ -85,6 +92,9 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
       input2_num_ *= shape2[i];
     }
 
+    grad_x_ = GetAttr<bool>(kernel_node, "grad_x");
+    grad_y_ = GetAttr<bool>(kernel_node, "grad_y");
+
     InitSizeLists();
     return true;
   }
@@ -136,6 +146,8 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
   int x1_shape_[4] = {1, 1, 1, 1};
   int x2_shape_[4] = {1, 1, 1, 1};
   int dy_shape_[4] = {1, 1, 1, 1};
+  bool grad_x_;
+  bool grad_y_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
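As a sanity check on the gating semantics, the no-broadcast case can be mirrored by a plain C++ reference on the CPU; this sketch is not part of the sources, and MaximumGradReference plus the sample values are made up for illustration:

#include <cstdio>
#include <vector>

// CPU reference for MaximumGrad with on-demand gradients: dy flows to dx1 where
// x1 > x2 (if grad_x is requested) and to dx2 where x1 <= x2 (if grad_y is requested).
void MaximumGradReference(const std::vector<float> &x1, const std::vector<float> &x2, const std::vector<float> &dy,
                          bool grad_x, bool grad_y, std::vector<float> *dx1, std::vector<float> *dx2) {
  dx1->assign(x1.size(), 0.0f);
  dx2->assign(x2.size(), 0.0f);
  for (size_t i = 0; i < x1.size(); ++i) {
    if (grad_x && x1[i] > x2[i]) {
      (*dx1)[i] += dy[i];
    } else if (grad_y && x1[i] <= x2[i]) {
      (*dx2)[i] += dy[i];
    }
  }
}

int main() {
  std::vector<float> x1 = {1.f, 5.f, 2.f}, x2 = {3.f, 4.f, 2.f}, dy = {1.f, 1.f, 1.f};
  std::vector<float> dx1, dx2;
  MaximumGradReference(x1, x2, dy, /*grad_x=*/true, /*grad_y=*/false, &dx1, &dx2);
  for (size_t i = 0; i < x1.size(); ++i) printf("dx1[%zu]=%g  dx2[%zu]=%g\n", i, dx1[i], i, dx2[i]);
  return 0;
}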