forked from mindspore-Ecosystem/mindspore
!48350 修复median GPU算子当输入为空算子时错误输出的问题
Merge pull request !48350 from wangtongyu6/fix_median_bug_gpu
This commit is contained in:
commit
ce87f6522d
|
@ -23,6 +23,7 @@
|
|||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -36,6 +37,9 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output0_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
S *output1_addr = nullptr;
|
||||
|
@ -74,6 +78,16 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
|
|||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
std::vector<std::vector<int64_t>> input_shapes;
|
||||
std::vector<int64_t> inp_shape = inputs[0]->GetShapeVector();
|
||||
input_shapes.emplace_back(inp_shape);
|
||||
std::vector<size_t> input_size_list;
|
||||
int inp_flag =
|
||||
cukernel::CalShapesSizeInBytes<T>(input_shapes, kMedianInputsNum, kernel_name_, "input_shapes", &input_size_list);
|
||||
if (inp_flag == -1) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
is_null_input_ = inp_flag == 1;
|
||||
axis_ = attr_axis_;
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
if (global_median_) {
|
||||
|
@ -137,6 +151,7 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
|
|||
int64_t attr_axis_;
|
||||
int64_t axis_;
|
||||
std::vector<int64_t> input_shape_;
|
||||
bool is_null_input_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue