From e0ba8eace6eae42a61cf62cdb84bf56a4fa7b6e6 Mon Sep 17 00:00:00 2001 From: hezhenhao1 Date: Wed, 11 May 2022 07:58:33 +0800 Subject: [PATCH] Remove KRET_INVALID_SHAPE of IndexFill. --- .../kernel/arrays/index_fill_gpu_kernel.cc | 24 ++++++------------- .../gpu/kernel/arrays/index_fill_gpu_kernel.h | 2 -- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.cc index 1b80d866bd3..48f0bc6896a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.cc @@ -36,8 +36,6 @@ bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std return false; } kernel_func_ = func_list_[index].second; - data_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex0).first)); - dim_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex1).first)); return true; } @@ -50,29 +48,21 @@ void IndexFillGpuKernelMod::ResetResource() noexcept { int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs, - const std::map &) { + const std::map &inputsOnHost) { + ResetResource(); + int ret = KRET_OK; + if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) { + return ret; + } CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_); - ResetResource(); x_shape_ = inputs.at(kIndex0)->GetShapeVector(); x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{}); - if (x_num_ <= 0) { - return KRET_INVALID_SHAPE; - } auto index_shape = inputs.at(kIndex2)->GetShapeVector(); index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{}); - if (index_num_ < 0) { - return KRET_INVALID_SHAPE; - } - - input_size_list_.push_back(LongToSize(x_num_) * data_size_); - input_size_list_.push_back(dim_size_); - input_size_list_.push_back(LongToSize(index_num_) * sizeof(kIndexType)); - input_size_list_.push_back(data_size_); - output_size_list_.push_back(LongToSize(x_num_) * data_size_); workspace_size_list_.push_back(sizeof(bool)); // Place out_bound. - return KRET_OK; + return ret; } template diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.h index 625e299ba41..73adb7b105f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.h @@ -62,8 +62,6 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod { IndexFillLaunchFunc kernel_func_; int64_t x_num_; int64_t index_num_; - size_t data_size_; // That is, sizeof(DataType). - size_t dim_size_; // That is, sizeof(DimType) std::vector x_shape_{}; }; } // namespace kernel