Remove KRET_INVALID_SHAPE of IndexFill.

This commit is contained in:
hezhenhao1 2022-05-11 07:58:33 +08:00
parent fdb9c9cdd7
commit e0ba8eace6
2 changed files with 7 additions and 19 deletions

View File

@ -36,8 +36,6 @@ bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
return false; return false;
} }
kernel_func_ = func_list_[index].second; kernel_func_ = func_list_[index].second;
data_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex0).first));
dim_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex1).first));
return true; return true;
} }
@ -50,29 +48,21 @@ void IndexFillGpuKernelMod::ResetResource() noexcept {
int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs, int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) { const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
ResetResource();
int ret = KRET_OK;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) {
return ret;
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_); CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_);
ResetResource();
x_shape_ = inputs.at(kIndex0)->GetShapeVector(); x_shape_ = inputs.at(kIndex0)->GetShapeVector();
x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{}); x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{});
if (x_num_ <= 0) {
return KRET_INVALID_SHAPE;
}
auto index_shape = inputs.at(kIndex2)->GetShapeVector(); auto index_shape = inputs.at(kIndex2)->GetShapeVector();
index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{}); index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{});
if (index_num_ < 0) {
return KRET_INVALID_SHAPE;
}
input_size_list_.push_back(LongToSize(x_num_) * data_size_);
input_size_list_.push_back(dim_size_);
input_size_list_.push_back(LongToSize(index_num_) * sizeof(kIndexType));
input_size_list_.push_back(data_size_);
output_size_list_.push_back(LongToSize(x_num_) * data_size_);
workspace_size_list_.push_back(sizeof(bool)); // Place out_bound. workspace_size_list_.push_back(sizeof(bool)); // Place out_bound.
return KRET_OK; return ret;
} }
template <typename DataType, typename DimType> template <typename DataType, typename DimType>

View File

@ -62,8 +62,6 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
IndexFillLaunchFunc kernel_func_; IndexFillLaunchFunc kernel_func_;
int64_t x_num_; int64_t x_num_;
int64_t index_num_; int64_t index_num_;
size_t data_size_; // That is, sizeof(DataType).
size_t dim_size_; // That is, sizeof(DimType)
std::vector<int64_t> x_shape_{}; std::vector<int64_t> x_shape_{};
}; };
} // namespace kernel } // namespace kernel