forked from mindspore-Ecosystem/mindspore
Remove KRET_INVALID_SHAPE of IndexFill.
This commit is contained in:
parent
fdb9c9cdd7
commit
e0ba8eace6
|
@ -36,8 +36,6 @@ bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
kernel_func_ = func_list_[index].second;
|
kernel_func_ = func_list_[index].second;
|
||||||
data_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex0).first));
|
|
||||||
dim_size_ = GetTypeByte(TypeIdToType(kernel_attr.GetInputAttr(kIndex1).first));
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,29 +48,21 @@ void IndexFillGpuKernelMod::ResetResource() noexcept {
|
||||||
|
|
||||||
int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
const std::vector<KernelTensorPtr> &outputs,
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||||
|
ResetResource();
|
||||||
|
int ret = KRET_OK;
|
||||||
|
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
|
||||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_);
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_);
|
||||||
ResetResource();
|
|
||||||
x_shape_ = inputs.at(kIndex0)->GetShapeVector();
|
x_shape_ = inputs.at(kIndex0)->GetShapeVector();
|
||||||
x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{});
|
x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{});
|
||||||
if (x_num_ <= 0) {
|
|
||||||
return KRET_INVALID_SHAPE;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto index_shape = inputs.at(kIndex2)->GetShapeVector();
|
auto index_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||||
index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{});
|
index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{});
|
||||||
if (index_num_ < 0) {
|
|
||||||
return KRET_INVALID_SHAPE;
|
|
||||||
}
|
|
||||||
|
|
||||||
input_size_list_.push_back(LongToSize(x_num_) * data_size_);
|
|
||||||
input_size_list_.push_back(dim_size_);
|
|
||||||
input_size_list_.push_back(LongToSize(index_num_) * sizeof(kIndexType));
|
|
||||||
input_size_list_.push_back(data_size_);
|
|
||||||
output_size_list_.push_back(LongToSize(x_num_) * data_size_);
|
|
||||||
workspace_size_list_.push_back(sizeof(bool)); // Place out_bound.
|
workspace_size_list_.push_back(sizeof(bool)); // Place out_bound.
|
||||||
return KRET_OK;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DataType, typename DimType>
|
template <typename DataType, typename DimType>
|
||||||
|
|
|
@ -62,8 +62,6 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
|
||||||
IndexFillLaunchFunc kernel_func_;
|
IndexFillLaunchFunc kernel_func_;
|
||||||
int64_t x_num_;
|
int64_t x_num_;
|
||||||
int64_t index_num_;
|
int64_t index_num_;
|
||||||
size_t data_size_; // That is, sizeof(DataType).
|
|
||||||
size_t dim_size_; // That is, sizeof(DimType)
|
|
||||||
std::vector<int64_t> x_shape_{};
|
std::vector<int64_t> x_shape_{};
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
|
|
Loading…
Reference in New Issue