forked from mindspore-Ecosystem/mindspore
!29585 Fix Gather CPU kernel
Merge pull request !29585 from zuochuanyong/fix_gather_cpu
This commit is contained in:
commit
93c75b2dc4
|
@ -28,9 +28,23 @@ constexpr size_t kGatherOutputsNum = 1;
|
|||
constexpr size_t kGatherInputParamsMaxDim = 4;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void GatherV2CpuKernelMod<T>::CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num == kGatherInputsNum + 1) {
|
||||
is_dynamic_shape_ = true;
|
||||
MS_LOG(DEBUG) << " GatherV2CPUKernel running in Dynamic Mode.";
|
||||
} else if (input_num == kGatherInputsNum) {
|
||||
MS_LOG(DEBUG) << " GatherV2CPUKernel running in Normal Mode.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2.";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GatherV2CpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
CheckParam(kernel_node);
|
||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
|
@ -48,7 +62,6 @@ template <typename T>
|
|||
bool GatherV2CpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherOutputsNum, kernel_name_);
|
||||
const auto *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
|
||||
const auto *indices_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
|
||||
|
|
|
@ -31,6 +31,8 @@ class GatherV2CpuKernelMod : public NativeCpuKernelMod {
|
|||
GatherV2CpuKernelMod() = default;
|
||||
~GatherV2CpuKernelMod() override = default;
|
||||
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
|
|
Loading…
Reference in New Issue