From 594571fd4cac49393e78b8e1ea96f55b82afe6a3 Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Wed, 21 Jul 2021 17:57:30 -0400 Subject: [PATCH] initial commit: fix 11 dts tickets fix ci --- .../kernel_compiler/gpu/arrays/argmax_gpu_kernel.h | 8 +++++++- .../kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h | 6 +++++- .../kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h | 3 +++ .../kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h | 3 ++- .../gpu/math/determinant_triangle_gpu_kernel.h | 8 +++++++- .../gpu/nn/adaptive_avg_pool2d_gpu_kernel.h | 6 ++++++ .../backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h | 6 +++++- .../backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h | 3 +++ .../gpu/nn/maxpool_with_argmax_gpu_kernel.h | 6 +++++- .../gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h | 8 ++++++-- 10 files changed, 49 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h index ed5e0dbb5c0..83eb5dc3149 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,8 @@ class ArgmaxGpuKernel : public GpuKernel { const std::vector &outputs, void *stream_ptr) override { T *input = GetDeviceAddress(inputs, 0); S *output = GetDeviceAddress(outputs, 0); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(output); CalArgmax(input, bound_, outer_size_, inner_size_, output, reinterpret_cast(stream_ptr)); return true; } @@ -46,6 +48,10 @@ class ArgmaxGpuKernel : public GpuKernel { auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); int64_t dims = shape.size(); int64_t axis = GetAttr(kernel_node, "axis"); + if (axis < -dims || axis >= dims) { + MS_LOG(EXCEPTION) << "axis must be in the range [-rank, rank)"; + } + if (axis < 0) { axis += dims; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h index ef792fa506b..c600157bb01 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,10 @@ class BroadcastToGpuKernel : public GpuKernel { MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than " << SHAPE_SIZE; } + if (output_shapes.size() < input_shapes.size()) { + MS_LOG(EXCEPTION) << "The rank of BroadcastTo's output cannot be smaller than the rank of the input."; + } + size_t offset = output_shapes.size() - input_shapes.size(); for (size_t i = 0; i < input_shapes.size(); i++) { input_shape_[i + offset] = input_shapes[i]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h index f7299596980..8192fbd9c38 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h @@ -53,6 +53,9 @@ class GatherV2GpuFwdKernel : public GpuKernel { Reshape(); } auto input_dim1 = input_shapes_[IntToSize(axis_)]; + + MS_EXCEPTION_IF_NULL(input_addr); + MS_EXCEPTION_IF_NULL(indices_addr); GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, reinterpret_cast(stream_ptr)); return true; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h index 02a6287913a..8bdfbbb7077 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ class OnesLikeGpuKernel : public GpuKernel { T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 0); int size = SizeToInt(input_size_ / sizeof(T)); + MS_EXCEPTION_IF_NULL(output); CalOnesLike(size, input, output, reinterpret_cast(stream_ptr)); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h index 139941d41f9..1fcd6403fa8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,6 +73,12 @@ class DetTriangleGpuKernel : public GpuKernel { for (size_t i = 0; i < input_shape.size(); i++) { input_size_ *= input_shape[i]; } + + if (input_shape.size() < 2) { + MS_LOG(ERROR) << "The input should have rank at least 2."; + return false; + } + matrix_n_ = input_shape[input_shape.size() - 1]; auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); for (size_t i = 0; i < output_shape.size(); i++) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h index a2e9ab88fa1..78d20659586 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h @@ -79,6 +79,12 @@ class AdaptiveAvgPool2DKernel : public GpuKernel { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); len = static_cast(input_shape.size()); + + if (len < 2) { + MS_LOG(ERROR) << "The input should have rank at least 2."; + return false; + } + input_height = static_cast(input_shape[len - 2]); input_width = static_cast(input_shape[len - 1]); size = static_cast(len == 3 ? input_shape[0] : input_shape[0] * input_shape[1]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h index aa16a83da6f..46cd0de59ef 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -113,6 +113,10 @@ class Conv2dGpuFwdKernel : public GpuKernel { std::vector pad_list_me = GetAttr>(kernel_node, "pad_list"); (void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list), [](const int64_t &value) { return static_cast(value); }); + if (pad_list.size() != 4) { + MS_LOG(EXCEPTION) << "Conv2dGpuFwdKernel pad_list must have length 4."; + } + pad_height_ = pad_list[0]; pad_width_ = pad_list[2]; use_pad_ = !((pad_height_ == pad_list[1]) && (pad_width_ == pad_list[3])); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h index 1d5b751889d..3c01afaa5e4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h @@ -90,6 +90,9 @@ class Conv3dGpuKernel : public GpuKernel { return true; } CheckTensorSize({in_shape}); + if (in_shape.size() != 5) { + MS_LOG(EXCEPTION) << "Conv3dGpuKernel input must have rank 5."; + } n_ = SizeToInt(in_shape[0]); c_ = SizeToInt(in_shape[1]); old_depth_ = SizeToInt(in_shape[2]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h index 634b6a24e7a..9c1d5a2c409 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -123,6 +123,10 @@ class MaxPoolWithArgmaxGpuFwdKernel : public GpuKernel { private: void SetPad() { + if (stride_height_ == 0) { + MS_LOG(EXCEPTION) << "stride height cannot be 0."; + } + pad_height_ = std::max( 0, (((input_height_ / stride_height_) * stride_height_ == input_height_ ? (input_height_ / stride_height_) : (input_height_ / stride_height_) + 1) - diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h index 472dc5a8b08..413db3927cb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -171,6 +171,10 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { void CheckShapeValidation(const std::vector &logits_shape, const std::vector &labels_shape) { size_t logits_dim_length = logits_shape.size(); size_t labels_dim_length = labels_shape.size(); + if (logits_dim_length == 0) { + MS_LOG(EXCEPTION) << "Logits shape cannot be empty"; + } + if (labels_dim_length != logits_dim_length) { MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length for " "SoftmaxCrossEntropyWithLogits, but got Labels " @@ -178,7 +182,7 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { << labels_dim_length << ", Logits shape length:" << logits_dim_length; } if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) { - MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last demension."; + MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last dimension."; } return; }