!40417 [assistant][ops] New GPU operator implementation, including MaxPool3D and MaxPool3DGrad

Merge pull request !40417 from 黎冠新/MaxPool3D
This commit is contained in:
i-robot 2022-11-29 13:50:58 +00:00 committed by Gitee
commit 4d34360b39
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 64 additions and 65 deletions

View File

@ -26,6 +26,8 @@ MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).Add
PoolingFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PoolingFwdGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
PoolingFwdGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

View File

@ -30,6 +30,7 @@
namespace mindspore {
namespace kernel {
constexpr auto kNumberFive = 5;
constexpr auto kAvgPool = "AvgPool";
constexpr auto kAvgPool3D = "AvgPool3D";
@ -72,14 +73,21 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, input_addr, &beta,
output_descriptor_, output_addr),
"cudnnPoolingForward failed");
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
if (cudnn_data_type_ == CUDNN_DATA_DOUBLE) {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, input_addr, &beta,
output_descriptor_, output_addr),
"cudnnPoolingForward failed");
} else {
const float alphaf = static_cast<float>(alpha);
const float betaf = static_cast<float>(beta);
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alphaf, input_descriptor_, input_addr, &betaf,
output_descriptor_, output_addr),
"cudnnPoolingForward failed");
}
if (divisor_override_ != 0) {
T *work_addr = GetDeviceAddress<T>(workspace, 0);
size_t output_num = output_size_ / sizeof(T);
@ -129,6 +137,12 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
auto input_shape = inputs[0]->GetDeviceShapeAdaptively();
auto output_shape = outputs[0]->GetDeviceShapeAdaptively();
@ -299,7 +313,7 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
[](const int64_t &value) { return static_cast<int>(value); });
int windowDimA[3] = {window_depth, window_height, window_width};
int paddingA[3] = {0, 0, 0};
if (stride_.size() < 5) {
if (stride_.size() < kNumberFive) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'strides' cannot be less than 5, but got "
<< stride_.size();
}

View File

@ -138,10 +138,8 @@ bool PoolingGradGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
dy = GetDeviceAddress<T>(inputs, kIndex2);
dx = GetDeviceAddress<T>(outputs, kIndex0);
}
const float alpha = 1;
const float beta = 0;
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
if (divisor_override_ != 0) {
T *work_addr = GetDeviceAddress<T>(workspace, kIndex2);
T *dy_work_addr = GetDeviceAddress<T>(workspace, kIndex3);
@ -163,18 +161,36 @@ bool PoolingGradGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
}
ElewiseArith(output_num, BROADCAST_TYPE_MUL, dy_work_addr, work_addr, dy_work_addr,
reinterpret_cast<cudaStream_t>(cuda_stream_));
if (cudnn_data_type_ == CUDNN_DATA_DOUBLE) {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy_work_addr,
x_descriptor_, x_data, &beta, dx_descriptor_, dx),
"For '" + kernel_name_ + "', cudnnPoolingBackward failed");
} else {
const float alphaf = static_cast<float>(alpha);
const float betaf = static_cast<float>(beta);
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alphaf, y_descriptor_, y, dy_descriptor_,
dy_work_addr, x_descriptor_, x_data, &betaf, dx_descriptor_, dx),
"For '" + kernel_name_ + "', cudnnPoolingBackward failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy_work_addr,
x_descriptor_, x_data, &beta, dx_descriptor_, dx),
"For '" + kernel_name_ + "', cudnnPoolingBackward failed");
return true;
}
if (cudnn_data_type_ == CUDNN_DATA_DOUBLE) {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
x_descriptor_, x_data, &beta, dx_descriptor_, dx),
"For '" + kernel_name_ + "', cudnnPoolingBackward failed");
} else {
const float alphaf = static_cast<float>(alpha);
const float betaf = static_cast<float>(beta);
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alphaf, y_descriptor_, y, dy_descriptor_, dy,
x_descriptor_, x_data, &betaf, dx_descriptor_, dx),
"For '" + kernel_name_ + "', cudnnPoolingBackward failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
x_descriptor_, x_data, &beta, dx_descriptor_, dx),
"For '" + kernel_name_ + "', cudnnPoolingBackward failed");
return true;
}
@ -497,7 +513,13 @@ std::map<std::string, std::vector<std::pair<KernelAttr, PoolingGradGpuKernelMod:
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&PoolingGradGpuKernelMod::LaunchKernel<half>}}},
&PoolingGradGpuKernelMod::LaunchKernel<half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&PoolingGradGpuKernelMod::LaunchKernel<double>}}},
{kAvgPoolGrad,
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -73,7 +73,7 @@ TypePtr MaxPool3DGradInferType(const PrimitivePtr &primitive, const std::vector<
MS_EXCEPTION_IF_NULL(item);
}
auto x_dtype = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeValid("input", x_dtype, valid_types, op_name);
}

View File

@ -211,7 +211,7 @@ TypePtr MaxPool3DInferType(const PrimitivePtr &primitive, const std::vector<Abst
MS_EXCEPTION_IF_NULL(item);
}
auto x_dtype = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, op_name);
}
} // namespace

View File

@ -1893,7 +1893,7 @@ class MaxPoolWithArgmax(Primitive):
self.add_prim_attr("strides", self.strides)
class MaxPool3D(PrimitiveWithInfer):
class MaxPool3D(Primitive):
r"""
3D max pooling operation.
@ -1940,7 +1940,7 @@ class MaxPool3D(PrimitiveWithInfer):
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
Data type must be float16 or float32.
Data type must be float16, float32 or float64.
Outputs:
Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`. Has the data type of `x`.
@ -2007,45 +2007,6 @@ class MaxPool3D(PrimitiveWithInfer):
validator.check_non_negative_int(item, 'pad_list item', self.name)
self.add_prim_attr("pad_list", self.pad_list)
def infer_shape(self, x_shape):
validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
batch, channel, input_d, input_h, input_w = x_shape
self.add_prim_attr("x_shape", x_shape)
_, _, kernel_d, kernel_h, kernel_w = self.kernel_size
_, _, stride_d, stride_h, stride_w = self.strides
if self.pad_mode == "VALID":
out_d = math.ceil((input_d - (kernel_d - 1)) / stride_d)
out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
elif self.pad_mode == "SAME":
out_d = math.ceil(input_d / stride_d)
out_h = math.ceil(input_h / stride_h)
out_w = math.ceil(input_w / stride_w)
else:
out_d = ((input_d + self.pad_list[0] + self.pad_list[1] -
(kernel_d - 1) - 1) / stride_d) + 1
out_h = ((input_h + self.pad_list[2] + self.pad_list[3] -
(kernel_h - 1) - 1) / stride_h) + 1
out_w = ((input_w + self.pad_list[4] + self.pad_list[5] -
(kernel_w - 1) - 1) / stride_w) + 1
if self.ceil_mode:
out_d = math.ceil(out_d)
out_h = math.ceil(out_h)
out_w = math.ceil(out_w)
else:
out_d = math.floor(out_d)
out_h = math.floor(out_h)
out_w = math.floor(out_w)
out_shape = [batch, channel, out_d, out_h, out_w]
_check_shape('output', out_shape, self.name)
return out_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype
class MaxUnpool2D(Primitive):
r"""