forked from mindspore-Ecosystem/mindspore
!29585 Fix Gather CPU kernel
Merge pull request !29585 from zuochuanyong/fix_gather_cpu
This commit is contained in:
commit
93c75b2dc4
|
@ -28,9 +28,23 @@ constexpr size_t kGatherOutputsNum = 1;
|
||||||
constexpr size_t kGatherInputParamsMaxDim = 4;
|
constexpr size_t kGatherInputParamsMaxDim = 4;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void GatherV2CpuKernelMod<T>::CheckParam(const CNodePtr &kernel_node) {
|
||||||
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
|
if (input_num == kGatherInputsNum + 1) {
|
||||||
|
is_dynamic_shape_ = true;
|
||||||
|
MS_LOG(DEBUG) << " GatherV2CPUKernel running in Dynamic Mode.";
|
||||||
|
} else if (input_num == kGatherInputsNum) {
|
||||||
|
MS_LOG(DEBUG) << " GatherV2CPUKernel running in Normal Mode.";
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void GatherV2CpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
void GatherV2CpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
|
CheckParam(kernel_node);
|
||||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||||
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||||
|
@ -48,7 +62,6 @@ template <typename T>
|
||||||
bool GatherV2CpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
bool GatherV2CpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
const std::vector<kernel::AddressPtr> &,
|
const std::vector<kernel::AddressPtr> &,
|
||||||
const std::vector<kernel::AddressPtr> &outputs) {
|
const std::vector<kernel::AddressPtr> &outputs) {
|
||||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherInputsNum, kernel_name_);
|
|
||||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherOutputsNum, kernel_name_);
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherOutputsNum, kernel_name_);
|
||||||
const auto *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
|
const auto *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
|
||||||
const auto *indices_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
|
const auto *indices_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
|
||||||
|
|
|
@ -31,6 +31,8 @@ class GatherV2CpuKernelMod : public NativeCpuKernelMod {
|
||||||
GatherV2CpuKernelMod() = default;
|
GatherV2CpuKernelMod() = default;
|
||||||
~GatherV2CpuKernelMod() override = default;
|
~GatherV2CpuKernelMod() override = default;
|
||||||
|
|
||||||
|
void CheckParam(const CNodePtr &kernel_node);
|
||||||
|
|
||||||
void InitKernel(const CNodePtr &kernel_node) override;
|
void InitKernel(const CNodePtr &kernel_node) override;
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
|
Loading…
Reference in New Issue