!39285 AdaptiveMaxPool2D grad GPU ops bugfix

Merge pull request !39285 from 胡安东/0730bugfix
i-robot 2022-08-01 12:14:24 +00:00 committed by Gitee
commit db23bd8b8d
2 changed files with 17 additions and 9 deletions


@@ -62,8 +62,9 @@ bool AdaptiveMaxPool2DGradGpuKernelMod::Launch(const std::vector<AddressPtr> &in
   std::vector<void *> work_ptrs = ConvertPtrs(workspace);
   std::vector<void *> output_ptrs = ConvertPtrs(outputs);
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(output_ptrs[0], 0, outputs[0]->size),
-                                     "failed to set cuda memory with zeros.");
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+    cudaMemsetAsync(output_ptrs[0], 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
+    "failed to set cuda memory with zeros.");
   if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
     return false;
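
What this hunk fixes: the three-argument overload of cudaMemsetAsync enqueues the zero-fill on the default CUDA stream, not on the stream the kernel's own work runs on, so nothing ordered the memset against the gradient kernel dispatched through helper_ptr_->Process(). Passing stream_ptr makes the memset stream-ordered with that work. A minimal host-side sketch of the guarantee, assuming a non-blocking stream as frameworks typically create (standalone C++ against the CUDA runtime, not MindSpore code):

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  // Non-blocking stream: no implicit synchronization with the default stream,
  // which is exactly the situation where the missing stream argument races.
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  const size_t n = 1024;
  float *dx = nullptr;
  cudaMalloc(&dx, n * sizeof(float));

  // Buggy form: cudaMemsetAsync(dx, 0, n * sizeof(float)); runs on the
  // default stream and is unordered with respect to `stream`.
  // Fixed form: the memset is enqueued on `stream`, so every command
  // submitted to `stream` afterwards is guaranteed to observe the zeros.
  cudaMemsetAsync(dx, 0, n * sizeof(float), stream);

  // A later operation on the same stream (here a readback) sees the zeros.
  std::vector<float> host(n, -1.0f);
  cudaMemcpyAsync(host.data(), dx, n * sizeof(float), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  std::printf("host[0] = %f\n", host[0]);  // prints 0.000000

  cudaFree(dx);
  cudaStreamDestroy(stream);
  return 0;
}

Within one stream, commands execute in enqueue order, so the copy above must observe the zeroed buffer; across streams there is no such guarantee without events or explicit synchronization.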
@@ -85,11 +86,6 @@ bool AdaptiveMaxPool2DGradGpuKernelMod::Init(const BaseOperatorPtr &base_operato
   helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
   helper_ptr_->SetKernelParam(attr_ptr_);
-  int ret = Resize(kernel_ptr, inputs, outputs);
-  if (ret == KRET_RESIZE_FAILED) {
-    return false;
-  }
   return true;
 }
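
The other half of the kernel-side change: Init() previously invoked Resize() itself and returned false on KRET_RESIZE_FAILED. Under dynamic shapes the first Resize() can legitimately report that shapes are not yet known, and the framework re-invokes Resize() once concrete shapes arrive, so Init() now leaves resizing to the framework entirely; see the sketch after the next hunk.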
@@ -97,6 +93,11 @@ int AdaptiveMaxPool2DGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operat
                                               const std::vector<KernelTensorPtr> &inputs,
                                               const std::vector<KernelTensorPtr> &outputs,
                                               const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  int ret = KernelMod::Resize(base_operator, inputs, outputs);
+  if (ret != KRET_OK) {
+    return ret;
+  }
   std::vector<std::vector<int64_t>> input_shapes;
   std::vector<std::vector<int64_t>> output_shapes;
   std::vector<int64_t> input_shape = inputs[0]->GetShapeVector();
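
A compilable sketch of the pattern this hunk adopts, with names simplified (the real KernelMod, KernelTensorPtr, and KRET_* codes live in MindSpore): run the base class's common resize bookkeeping first and forward any non-KRET_OK status unchanged, so an "unknown shape" report is propagated rather than flattened into a failure:

#include <cstdint>
#include <vector>

enum KernelRet { KRET_OK = 0, KRET_UNKNOWN_SHAPE = 1, KRET_RESIZE_FAILED = 2 };

struct KernelModBase {
  virtual ~KernelModBase() = default;
  // Common bookkeeping: recompute size lists from the current shapes and
  // report KRET_UNKNOWN_SHAPE while any dimension is still dynamic.
  virtual int Resize(const std::vector<std::vector<int64_t>> &input_shapes) {
    for (const auto &shape : input_shapes) {
      for (int64_t dim : shape) {
        if (dim < 0) {
          return KRET_UNKNOWN_SHAPE;  // dynamic dim not yet resolved
        }
      }
    }
    return KRET_OK;
  }
};

struct AdaptiveMaxPool2DGradKernel : KernelModBase {
  int Resize(const std::vector<std::vector<int64_t>> &input_shapes) override {
    int ret = KernelModBase::Resize(input_shapes);  // base bookkeeping first
    if (ret != KRET_OK) {
      return ret;  // forward UNKNOWN_SHAPE / RESIZE_FAILED unchanged
    }
    // ... kernel-specific shape work, only once shapes are concrete ...
    return KRET_OK;
  }
};

int main() {
  AdaptiveMaxPool2DGradKernel kernel;
  // A -1 dimension means "unknown"; the derived Resize forwards the status
  // instead of converting it into a hard failure.
  int ret = kernel.Resize({{3, -1, 32, 32}});
  return ret == KRET_UNKNOWN_SHAPE ? 0 : 1;
}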


@@ -29,6 +29,8 @@ namespace mindspore {
 namespace ops {
 namespace {
 constexpr size_t inputArgLen = 3;
+constexpr int64_t kDynamicRankVal = -2;
+constexpr size_t kDynamicRankL = 1;
 abstract::ShapePtr AdaptiveMaxPool2DGradInferShape(const PrimitivePtr &primitive,
                                                    const std::vector<AbstractBasePtr> &input_args) {
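
For context on the two new constants: MindSpore's shape convention encodes an unknown dimension as -1 and a completely unknown rank as the single-element shape {-2}. kDynamicRankVal is that sentinel value, and kDynamicRankL is the length (1) a shape vector has in the dynamic-rank case, so the check added below reads as "y_grad's shape is exactly {-2}".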
@@ -49,12 +51,17 @@ abstract::ShapePtr AdaptiveMaxPool2DGradInferShape(const PrimitivePtr &primitive
   const int64_t x_dims = SizeToLong(x_shape.size());
   const int64_t argmax_dims = SizeToLong(argmax_shape.size());
-  (void)CheckAndConvertUtils::CheckInteger("y_grad_dims", y_grad_dims, kEqual, x_dims, kNameAdaptiveMaxPool2DGrad);
-  (void)CheckAndConvertUtils::CheckInteger("argmax_dims", argmax_dims, kEqual, x_dims, kNameAdaptiveMaxPool2DGrad);
+  if (y_grad_shape.size() == kDynamicRankL && y_grad_shape[0] == kDynamicRankVal) {
+    ShapeVector out_shape = {kDynamicRankVal};
+    return std::make_shared<abstract::Shape>(out_shape);
+  }
   CheckAndConvertUtils::CheckInRange("y_grad_dim", y_grad_dims, kIncludeBoth, {3, 4}, kNameAdaptiveMaxPool2DGrad);
   CheckAndConvertUtils::CheckInRange("x_dim", x_dims, kIncludeBoth, {3, 4}, kNameAdaptiveMaxPool2DGrad);
   CheckAndConvertUtils::CheckInRange("argmax_dim", argmax_dims, kIncludeBoth, {3, 4}, kNameAdaptiveMaxPool2DGrad);
+  (void)CheckAndConvertUtils::CheckInteger("y_grad_dims", y_grad_dims, kEqual, x_dims, kNameAdaptiveMaxPool2DGrad);
+  (void)CheckAndConvertUtils::CheckInteger("argmax_dims", argmax_dims, kEqual, x_dims, kNameAdaptiveMaxPool2DGrad);
   if (y_grad_shape != argmax_shape) {
     MS_EXCEPTION(ValueError) << "For '" << op_name
                              << "', the shape of 'y_grad' should be consistent with the shape of 'argmax'.";