!29585 Fix Gather CPU kernel

Merge pull request !29585 from zuochuanyong/fix_gather_cpu
This commit is contained in:
i-robot 2022-02-07 09:23:18 +00:00 committed by Gitee
commit 93c75b2dc4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 16 additions and 1 deletions

View File

@ -28,9 +28,23 @@ constexpr size_t kGatherOutputsNum = 1;
constexpr size_t kGatherInputParamsMaxDim = 4;
} // namespace
// Validates the number of input tensors on `kernel_node` and selects the running mode.
//
//   * kGatherInputsNum + 1 inputs -> dynamic mode; is_dynamic_shape_ is set to true.
//   * kGatherInputsNum inputs     -> normal mode; is_dynamic_shape_ is left unchanged.
//   * any other count             -> reported through MS_LOG(EXCEPTION).
//
// NOTE(review): the extra input accepted in dynamic mode is presumably the axis tensor;
// confirm against the Gather operator definition.
template <typename T>
void GatherV2CpuKernelMod<T>::CheckParam(const CNodePtr &kernel_node) {
  const size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num == kGatherInputsNum + 1) {
    is_dynamic_shape_ = true;
    MS_LOG(DEBUG) << " GatherV2CPUKernel running in Dynamic Mode.";
  } else if (input_num == kGatherInputsNum) {
    MS_LOG(DEBUG) << " GatherV2CPUKernel running in Normal Mode.";
  } else {
    // Report every accepted count, derived from the constant instead of a hard-coded literal,
    // so the message stays accurate for both normal and dynamic mode.
    MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs " << kGatherInputsNum
                      << " or " << (kGatherInputsNum + 1) << ".";
  }
}
template <typename T>
void GatherV2CpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
CheckParam(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
@ -48,7 +62,6 @@ template <typename T>
bool GatherV2CpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherOutputsNum, kernel_name_);
const auto *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
const auto *indices_data = reinterpret_cast<int32_t *>(inputs[1]->addr);

View File

@ -31,6 +31,8 @@ class GatherV2CpuKernelMod : public NativeCpuKernelMod {
GatherV2CpuKernelMod() = default;
~GatherV2CpuKernelMod() override = default;
void CheckParam(const CNodePtr &kernel_node);
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,