forked from mindspore-Ecosystem/mindspore
Remove KRET_INVALID_SHAPE of IndexFill.
This commit is contained in:
parent
fdb9c9cdd7
commit
e0ba8eace6
|
@ -36,8 +36,6 @@ bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
|
|||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
data_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex0).first));
|
||||
dim_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex1).first));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -50,29 +48,21 @@ void IndexFillGpuKernelMod::ResetResource() noexcept {
|
|||
|
||||
int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
ResetResource();
|
||||
int ret = KRET_OK;
|
||||
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_);
|
||||
ResetResource();
|
||||
x_shape_ = inputs.at(kIndex0)->GetShapeVector();
|
||||
x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{});
|
||||
if (x_num_ <= 0) {
|
||||
return KRET_INVALID_SHAPE;
|
||||
}
|
||||
|
||||
auto index_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||
index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{});
|
||||
if (index_num_ < 0) {
|
||||
return KRET_INVALID_SHAPE;
|
||||
}
|
||||
|
||||
input_size_list_.push_back(LongToSize(x_num_) * data_size_);
|
||||
input_size_list_.push_back(dim_size_);
|
||||
input_size_list_.push_back(LongToSize(index_num_) * sizeof(kIndexType));
|
||||
input_size_list_.push_back(data_size_);
|
||||
output_size_list_.push_back(LongToSize(x_num_) * data_size_);
|
||||
workspace_size_list_.push_back(sizeof(bool)); // Place out_bound.
|
||||
return KRET_OK;
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename DataType, typename DimType>
|
||||
|
|
|
@ -62,8 +62,6 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
|
|||
IndexFillLaunchFunc kernel_func_;
|
||||
int64_t x_num_;
|
||||
int64_t index_num_;
|
||||
size_t data_size_; // That is, sizeof(DataType).
|
||||
size_t dim_size_; // That is, sizeof(DimType)
|
||||
std::vector<int64_t> x_shape_{};
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
Loading…
Reference in New Issue