!48350 修复median GPU算子当输入为空算子时错误输出的问题

Merge pull request !48350 from wangtongyu6/fix_median_bug_gpu
This commit is contained in:
i-robot 2023-02-09 08:29:01 +00:00 committed by Gitee
commit ce87f6522d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 15 additions and 0 deletions

View File

@ -23,6 +23,7 @@
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
namespace mindspore {
namespace kernel {
@ -36,6 +37,9 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output0_addr = GetDeviceAddress<T>(outputs, 0);
S *output1_addr = nullptr;
@ -74,6 +78,16 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
if (ret != 0) {
return ret;
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<int64_t> inp_shape = inputs[0]->GetShapeVector();
input_shapes.emplace_back(inp_shape);
std::vector<size_t> input_size_list;
int inp_flag =
cukernel::CalShapesSizeInBytes<T>(input_shapes, kMedianInputsNum, kernel_name_, "input_shapes", &input_size_list);
if (inp_flag == -1) {
return KRET_RESIZE_FAILED;
}
is_null_input_ = inp_flag == 1;
axis_ = attr_axis_;
input_shape_ = inputs[0]->GetShapeVector();
if (global_median_) {
@ -137,6 +151,7 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
int64_t attr_axis_;
int64_t axis_;
std::vector<int64_t> input_shape_;
bool is_null_input_;
};
} // namespace kernel
} // namespace mindspore