add dynamic and vmap support for gatherd

parent 515e54d361
commit 83280c2230
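
In short: the CPU and GPU GatherD kernels move off the deprecated InitKernel/CNode interface onto Init/Resize so the op works under dynamic shape, `dim` becomes a value-depend input that may arrive as a scalar tensor, the infer propagates min/max shapes, and a vmap rule is registered for `P.GatherD`. A minimal sketch of what the vmap side enables — assuming PyNative-style execution and the `mindspore.ops.functional.vmap` interface that the new test imports; shapes and values are illustrative:

import numpy as np
from mindspore import Tensor, ops
from mindspore.ops.functional import vmap

gather_d = ops.GatherD()

def gather_fn(xi, idx):
    # `dim` must stay a const int under vmap; 1 refers to the per-element layout.
    return gather_d(xi, 1, idx)

x = Tensor(np.arange(24, dtype=np.float32).reshape(2, 3, 4))  # batch of 2
index = Tensor(np.zeros((2, 3, 4), dtype=np.int64))

out = vmap(gather_fn, in_axes=(0, 0), out_axes=0)(x, index)
print(out.shape)  # (2, 3, 4)

Here both operands carry the batch on axis 0; the registered rule also handles the case where only one of them is batched (see the example after the vmap hunk below).
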
@@ -82,6 +82,7 @@ bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
   }
   return true;
 }
+
 bool ArgmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs) {
   auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMax>(base_operator);
@@ -66,11 +66,32 @@ void CopyTask(size_t cur, std::vector<size_t> *pos, T *input, const I *index, co
 }
 } // namespace
 
-void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  index_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
+bool GatherDCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                               const std::vector<KernelTensorPtr> &outputs) {
+  kernel_name_ = base_operator->name();
+  if (inputs.size() != 3) {
+    MS_LOG(ERROR) << "GatherD input size must be equal to 3!";
+    return false;
+  }
+  if (auto ret = MatchKernelFunc(base_operator, inputs, outputs); !ret) {
+    return ret;
+  }
+  return true;
+}
+
+int GatherDCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                const std::vector<KernelTensorPtr> &outputs,
+                                const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
+    return ret;
+  }
+
+  const size_t kIndexIdx = 2;
+  auto input_shape = inputs[0]->GetShapeVector();
+  auto index_shape = inputs[kIndexIdx]->GetShapeVector();
+  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
+  (void)std::transform(index_shape.begin(), index_shape.end(), std::back_inserter(index_shape_), LongToSize);
+
   if (input_shape_.size() != index_shape_.size()) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_
                       << "', shape size of 'x' must be equal to 'index', but got shape size of 'x': "
@@ -78,20 +99,11 @@ void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   }
   output_shape_ = index_shape_;
 
-  std::vector<KernelAttr> support_list;
-  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
-                       [](const std::pair<KernelAttr, GatherDFunc> &pair) { return pair.first; });
-  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
-  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
-  if (!is_match) {
-    MS_LOG(EXCEPTION) << "GatherD does not support this kernel data type: " << kernel_attr;
-  }
-
-  kernel_func_ = func_list_[index].second;
+  return KRET_OK;
 }
 
 template <typename T, typename I>
-bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                        const std::vector<kernel::AddressPtr> &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherDInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherDOutputsNum, kernel_name_);
@@ -157,67 +169,71 @@ bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
   return true;
 }
 
-std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::GatherDFunc>> GatherDCpuKernelMod::func_list_ = {
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeFloat32),
-   &GatherDCpuKernelMod::LaunchKernel<float, int32_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddOutputAttr(kNumberTypeFloat32),
-   &GatherDCpuKernelMod::LaunchKernel<float, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat16)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeFloat16),
-   &GatherDCpuKernelMod::LaunchKernel<float16, int32_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat16)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddOutputAttr(kNumberTypeFloat16),
-   &GatherDCpuKernelMod::LaunchKernel<float16, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeInt32),
-   &GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddOutputAttr(kNumberTypeInt32),
-   &GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeInt64),
-   &GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddOutputAttr(kNumberTypeInt64),
-   &GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeBool)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeBool),
-   &GatherDCpuKernelMod::LaunchKernel<bool, int32_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeBool)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddOutputAttr(kNumberTypeBool),
-   &GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}};
+const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> &GatherDCpuKernelMod::GetFuncList() const {
+  static const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> func_list = {
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeFloat32),
+     &GatherDCpuKernelMod::LaunchKernel<float, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeFloat32),
+     &GatherDCpuKernelMod::LaunchKernel<float, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat16)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeFloat16),
+     &GatherDCpuKernelMod::LaunchKernel<float16, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat16)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeFloat16),
+     &GatherDCpuKernelMod::LaunchKernel<float16, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeInt32),
+     &GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeInt32),
+     &GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt64)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeInt64),
+     &GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt64)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeInt64),
+     &GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeBool)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeBool),
+     &GatherDCpuKernelMod::LaunchKernel<bool, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeBool)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeBool),
+     &GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}};
+
+  return func_list;
+}
 
 MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GatherD, GatherDCpuKernelMod);
 } // namespace kernel
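
For orientation, both kernels implement the standard GatherD semantics: the output has the shape of `index`, and `index` replaces the coordinate along `dim` only, e.g. `output[i, j] = x[i, index[i, j]]` for `dim == 1`. A tiny NumPy model of this — illustrative only, `gather_d` is a hypothetical helper and not part of the change:

import numpy as np

def gather_d(x, dim, index):
    # out has the shape of `index`; each element picks from `x` along `dim`.
    out = np.empty(index.shape, dtype=x.dtype)
    for pos in np.ndindex(*index.shape):
        src = list(pos)
        src[dim] = index[pos]
        out[pos] = x[tuple(src)]
    return out

x = np.array([[1, 2], [3, 4]])
idx = np.array([[0, 0], [1, 0]])
print(gather_d(x, 1, idx))  # [[1 1] [4 3]]
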
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
 #include <vector>
+#include <map>
 #include <memory>
 #include <utility>
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -23,25 +24,27 @@
 
 namespace mindspore {
 namespace kernel {
-class GatherDCpuKernelMod : public DeprecatedNativeCpuKernelMod {
+class GatherDCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GatherDCpuKernelMod> {
  public:
   GatherDCpuKernelMod() = default;
   ~GatherDCpuKernelMod() override = default;
 
-  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override {
-    return kernel_func_(this, inputs, outputs);
+    return kernel_func_(this, inputs, workspace, outputs);
   }
+  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
 
  private:
   template <typename T, typename I>
-  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
-  using GatherDFunc = std::function<bool(GatherDCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
-                                         const std::vector<kernel::AddressPtr> &)>;
-  static std::vector<std::pair<KernelAttr, GatherDFunc>> func_list_;
-  GatherDFunc kernel_func_;
+  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<kernel::AddressPtr> &outputs);
   std::vector<size_t> input_shape_;
   std::vector<size_t> index_shape_;
   std::vector<size_t> output_shape_;
@@ -15,92 +15,341 @@
  */
 
 #include "plugin/device/gpu/kernel/arrays/gather_gpu_kernel.h"
+#include <string>
+#include <algorithm>
 
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_TWO(
-  GatherD,
-  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-  GatherFwdGpuKernelMod, double, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD,
-  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-  GatherFwdGpuKernelMod, double, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  GatherFwdGpuKernelMod, float, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-  GatherFwdGpuKernelMod, float, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  GatherFwdGpuKernelMod, half, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-  GatherFwdGpuKernelMod, half, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  GatherFwdGpuKernelMod, int, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
-  GatherFwdGpuKernelMod, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
-  GatherFwdGpuKernelMod, int8_t, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
-  GatherFwdGpuKernelMod, int8_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
-  GatherFwdGpuKernelMod, int16_t, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
-  GatherFwdGpuKernelMod, int16_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
-  GatherFwdGpuKernelMod, int64_t, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-  GatherFwdGpuKernelMod, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
-  GatherFwdGpuKernelMod, uint, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
-  GatherFwdGpuKernelMod, uint, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
-  GatherFwdGpuKernelMod, uchar, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
-  GatherFwdGpuKernelMod, uchar, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
-  GatherFwdGpuKernelMod, bool, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
-  GatherFwdGpuKernelMod, bool, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
-  GatherFwdGpuKernelMod, uint32_t, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
-  GatherFwdGpuKernelMod, uint32_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
-  GatherFwdGpuKernelMod, uint64_t, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
-  GatherFwdGpuKernelMod, uint64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
-  GatherFwdGpuKernelMod, uint16_t, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
-  GatherFwdGpuKernelMod, uint16_t, int64_t)
+std::vector<std::pair<KernelAttr, GatherFwdGpuKernelMod::GatherFwdFunc>> GatherFwdGpuKernelMod::func_list_ = {
+  // For static shape case:
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+   &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+   &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+   &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+   &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+   &GatherFwdGpuKernelMod::LaunchKernel<half, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+   &GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<int, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
+   &GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+   &GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
+  // For dynamic shape case:
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeFloat64),
+   &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeFloat64),
+   &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeFloat16),
+   &GatherFwdGpuKernelMod::LaunchKernel<half, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeFloat16),
+   &GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<int, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt8)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt8)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt16)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt16)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt8),
+   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeBool),
+   &GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeBool),
+   &GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt64),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt16),
+   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>}};
+
+bool GatherFwdGpuKernelMod::SetDimParam(int64_t dim_value) {
+  int64_t x_rank = SizeToLong(input_shapes_.size());
+  if (dim_value < -x_rank || dim_value >= x_rank) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dim' must be in the range [-" << x_rank << "," << x_rank
+                  << "), but got " << dim_value;
+    return false;
+  }
+  if (dim_value < 0) {
+    dim_value += x_rank;
+  }
+
+  size_t dim_before_axis = 1;
+  for (size_t i = 0; i < LongToSize(dim_value); i++) {
+    dim_before_axis *= output_shapes_[i];
+  }
+  size_t dim_at_axis_input = input_shapes_[LongToSize(dim_value)];
+  size_t dim_at_axis_output = output_shapes_[LongToSize(dim_value)];
+  size_t dim_after_axis = 1;
+  for (size_t i = LongToSize(dim_value) + 1; i < output_shapes_.size(); i++) {
+    dim_after_axis *= output_shapes_[i];
+  }
+
+  const size_t k2Idx = 2;
+  const size_t k3Idx = 3;
+  dims_[0] = dim_before_axis;
+  dims_[1] = dim_at_axis_input;
+  dims_[k2Idx] = dim_at_axis_output;
+  dims_[k3Idx] = dim_after_axis;
+  return true;
+}
+
+bool GatherFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                 const std::vector<KernelTensorPtr> &outputs) {
+  kernel_name_ = base_operator->name();
+  size_t input_num = inputs.size();
+  const size_t kStaticInputNum = 2;
+  const size_t kDynInputNum = 3;
+  if (input_num == kStaticInputNum) {
+    is_dynamic_case_ = false;
+  } else if (input_num == kDynInputNum) {
+    is_dynamic_case_ = true;
+    const size_t kDynIndexIdx = 2;
+    index_idx_ = kDynIndexIdx;
+  } else {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
+    return false;
+  }
+
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
+    return false;
+  }
+  kernel_func_ = func_list_[index].second;
+
+  return true;
+}
+
+int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                  const std::vector<KernelTensorPtr> &outputs,
+                                  const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
+  if (ret != KRET_OK) {
+    return ret;
+  }
+
+  auto input_shapes = inputs[0]->GetShapeVector();
+  auto index_shapes = inputs[index_idx_]->GetShapeVector();
+  auto output_shapes = outputs[0]->GetShapeVector();
+
+  input_shapes_.clear();
+  index_shapes_.clear();
+  output_shapes_.clear();
+  std::transform(input_shapes.cbegin(), input_shapes.cend(), std::back_inserter(input_shapes_), LongToSize);
+  std::transform(index_shapes.cbegin(), index_shapes.cend(), std::back_inserter(index_shapes_), LongToSize);
+  std::transform(output_shapes.cbegin(), output_shapes.cend(), std::back_inserter(output_shapes_), LongToSize);
+
+  is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
+                   CHECK_SHAPE_NULL(index_shapes_, kernel_name_, "input_indices") ||
+                   CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");
+  if (is_null_input_) {
+    return KRET_OK;
+  }
+
+  if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input and output must be the equal to the "
+                  << "dimension of index: " << index_shapes_.size()
+                  << ", but got the dimension of input: " << input_shapes_.size()
+                  << ", the dimension of output: " << output_shapes_.size();
+    return KRET_RESIZE_FAILED;
+  }
+
+  int64_t dim_value = 0;
+  if (!is_dynamic_case_) {
+    const std::string kAttrDim = "dim";
+    auto dim_attr = base_operator->GetPrim()->GetAttr(kAttrDim);
+    if (dim_attr == nullptr) {
+      return KRET_RESIZE_FAILED;
+    }
+    dim_value = GetValue<int64_t>(dim_attr);
+  } else {
+    GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value);
+  }
+  if (!SetDimParam(dim_value)) {
+    return KRET_RESIZE_FAILED;
+  }
+
+  return KRET_OK;
+}
+
+std::vector<KernelAttr> GatherFwdGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, GatherFwdFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GatherD, GatherFwdGpuKernelMod);
 } // namespace kernel
 } // namespace mindspore
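
`SetDimParam` above reduces the gather to the four numbers stored in `dims_`: the product of the output dimensions before `dim`, the input and output extents at `dim`, and the product after `dim`. The flattening only lines up because every non-`dim` axis of `x` and `index` must agree, which the checks in `Resize` (and in the op infer below) enforce. A NumPy mirror of the same decomposition — `gather_d_flat` is a hypothetical name for illustration:

import numpy as np

def gather_d_flat(x, dim, index):
    # dims_[0..3] in the kernel: before, at-axis (input), at-axis (output), after
    before = int(np.prod(index.shape[:dim], dtype=np.int64))
    at_in = x.shape[dim]
    at_out = index.shape[dim]
    after = int(np.prod(index.shape[dim + 1:], dtype=np.int64))

    x3 = x.reshape(before, at_in, after)        # valid because non-dim axes agree
    idx3 = index.reshape(before, at_out, after)
    out = np.empty_like(idx3, dtype=x.dtype)
    for b in range(before):
        for i in range(at_out):
            for a in range(after):
                out[b, i, a] = x3[b, idx3[b, i, a], a]
    return out.reshape(index.shape)

x = np.arange(24).reshape(2, 3, 4)
idx = np.zeros((2, 3, 4), dtype=np.int64)
assert (gather_d_flat(x, 1, idx) == np.take_along_axis(x, idx, axis=1)).all()
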
@@ -14,124 +14,83 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
 
 #include <vector>
+#include <map>
+#include <utility>
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
-#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
+#include "plugin/factory/ms_factory.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cuh"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 
 namespace mindspore {
 namespace kernel {
-template <typename T, typename S>
-class GatherFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
+class GatherFwdGpuKernelMod : public NativeGpuKernelMod {
  public:
-  GatherFwdGpuKernelMod() : axis_(0), is_null_input_(false) {}
+  GatherFwdGpuKernelMod() {}
   ~GatherFwdGpuKernelMod() = default;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (!kernel_func_) {
+      MS_LOG(ERROR) << "GatherFwdGpu's kernel function is not initialized.";
+      return false;
+    }
+    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
+  }
+
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs,
+             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename T, typename S>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs, void *stream_ptr) {
     if (is_null_input_) {
       return true;
     }
     VARIABLE_NOT_USED(workspace);
-    T *input_addr = GetDeviceAddress<T>(inputs, 0);
-    S *index_addr = GetDeviceAddress<S>(inputs, 1);
-    T *output_addr = GetDeviceAddress<T>(outputs, 0);
-
-    Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3],
-           reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID);
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
-    kernel_node_ = kernel_node;
-    InitResource();
-    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
-    }
-    input_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    index_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    output_shapes_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
-                     CHECK_SHAPE_NULL(index_shapes_, kernel_name, "input_indices") ||
-                     CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
-    if (is_null_input_) {
-      InitSizeLists();
-      return true;
-    }
-    if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input and output must be the equal to the "
-                        << "dimension of index: " << index_shapes_.size()
-                        << ", but got the dimension of input: " << input_shapes_.size()
-                        << ", the dimension of output: " << output_shapes_.size();
-    }
-    int dims = SizeToInt(input_shapes_.size());
-    axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "dim"));
-    if (axis_ < -dims || axis_ >= dims) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
-                        << "), but got " << axis_;
-    }
-    if (axis_ < 0) {
-      axis_ += dims;
-    }
-    Reshape();
-    InitSizeLists();
+
+    auto input_addr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
+    auto index_addr = reinterpret_cast<S *>(inputs.at(index_idx_)->addr);
+    auto output_addr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
+    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
+
+    const size_t k2Idx = 2;
+    const size_t k3Idx = 3;
+    Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[k2Idx], dims_[k3Idx], cuda_stream,
+           GET_CTX_DEVICE_ID);
     return true;
   }
 
- protected:
-  void InitResource() override {}
-  void InitSizeLists() override {
-    size_t size = GetSize(input_shapes_, true);
-    input_size_list_.push_back(size);
-
-    size = GetSize(index_shapes_, false);
-    input_size_list_.push_back(size);
-
-    size = GetSize(output_shapes_, true);
-    output_size_list_.push_back(size);
-  }
-
- private:
-  void Reshape() {
-    size_t dim_before_axis = 1;
-    for (size_t i = 0; i < IntToSize(axis_); i++) {
-      dim_before_axis *= output_shapes_[i];
-    }
-    size_t dim_at_axis_input = input_shapes_[IntToSize(axis_)];
-    size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
-    size_t dim_after_axis = 1;
-    for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
-      dim_after_axis *= output_shapes_[i];
-    }
-
-    dims_[0] = dim_before_axis;
-    dims_[1] = dim_at_axis_input;
-    dims_[2] = dim_at_axis_output;
-    dims_[3] = dim_after_axis;
-    return;
-  }
-  size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
-    size_t result = flag ? sizeof(T) : sizeof(S);
-    for (size_t i = 0; i < shape.size(); i++) {
-      result *= shape[i];
-    }
-    return result;
-  }
+  bool SetDimParam(int64_t dim_value);
+
+  using GatherFwdFunc =
+    std::function<bool(GatherFwdGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
+                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
+
+  GatherFwdFunc kernel_func_;
+  static std::vector<std::pair<KernelAttr, GatherFwdFunc>> func_list_;
 
   std::vector<size_t> input_shapes_;
   std::vector<size_t> index_shapes_;
   std::vector<size_t> output_shapes_;
 
   size_t dims_[4] = {};
-  int axis_;
-  bool is_null_input_;
+  bool is_null_input_{false};
+  bool is_dynamic_case_{false};
+  size_t index_idx_{1};
 };
 } // namespace kernel
 } // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
@@ -385,10 +385,15 @@ std::optional<std::vector<int64_t>> GetDynamicAttrIntValue(
                       << "th input in the depend_tensor_map";
   }
   auto input_tensor = depend_iter->second;
-  const auto &input_shape = inputs[input_index]->GetShapeVector();
+  auto input_shape = inputs[input_index]->GetShapeVector();
+  // The shape kept in depend_tensor_map was processed if it was empty.
+  if (input_shape.empty()) {
+    input_shape.push_back(1);
+  }
   if (input_shape != input_tensor->shape()) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the " << input_index
-                      << "th input is different between the InferShape and the TensorShape";
+                      << "th input is different between the InferShape and the TensorShape: " << input_shape << " vs "
+                      << input_tensor->shape();
   }
   const auto &data_format = inputs[input_index]->GetFormat();
   if (data_format != mindspore::Format::DEFAULT_FORMAT && data_format != mindspore::Format::NCHW) {
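
The new empty-shape branch exists because a value-depend input can be a 0-d scalar tensor (for GatherD, the `dim` input): its inferred shape is `[]`, while the tensor cached in `depend_tensor_map` carries shape `[1]`, so the inferred shape is normalized before the comparison. The same rule in Python terms (`normalize_depend_shape` is a hypothetical helper, for illustration only):

def normalize_depend_shape(shape):
    # a 0-d scalar's inferred shape [] is stored as [1] in depend_tensor_map
    return list(shape) if shape else [1]

assert normalize_depend_shape([]) == [1]
assert normalize_depend_shape([3, 4]) == [3, 4]
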
@@ -69,6 +69,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
   static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
   static const auto &kGather = prim::kPrimGather->name();
   static const auto &kGatherV2 = prim::kPrimGatherV2->name();
+  static const auto &kGatherD = prim::kPrimGatherD->name();
   static const auto &kSparseGatherV2 = prim::kPrimSparseGatherV2->name();
   static const auto &kRange = prim::kPrimRange->name();
   static const auto &kRangeV2 = prim::kPrimRangeV2->name();

@@ -101,6 +102,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
     {kMatrixSetDiagV3, ShapeSet{2}},
     {kGather, ShapeSet{2}},
     {kGatherV2, ShapeSet{2}},
+    {kGatherD, ShapeSet{1}},
     {kSparseGatherV2, ShapeSet{2}},
     {kRange, ShapeSet{0, 1, 2}},
     {kRangeV2, ShapeSet{0, 1, 2}},
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,35 +26,97 @@ namespace mindspore {
 namespace ops {
 // gather_d
 namespace {
+int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
+  MS_EXCEPTION_IF_NULL(dim_ptr);
+  auto dim_value_ptr = dim_ptr->BuildValue();
+  MS_EXCEPTION_IF_NULL(dim_value_ptr);
+  auto dim_type_ptr = dim_ptr->BuildType();
+  MS_EXCEPTION_IF_NULL(dim_type_ptr);
+  int64_t dim_v = 0;
+  if (dim_value_ptr->isa<tensor::Tensor>()) {
+    auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
+    MS_EXCEPTION_IF_NULL(dim_tensor);
+    size_t data_size = dim_tensor->DataSize();
+    MS_EXCEPTION_IF_CHECK_FAIL(data_size == 1, "dim value is not equal to one!");
+    auto dim_type_id = dim_type_ptr->cast<TensorTypePtr>();
+    MS_EXCEPTION_IF_NULL(dim_type_id);
+    auto element = dim_type_id->element();
+    MS_EXCEPTION_IF_NULL(element);
+    if (element->type_id() == kNumberTypeInt32) {
+      auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
+      MS_EXCEPTION_IF_NULL(dim_data32);
+      dim_v = static_cast<int64_t>(*dim_data32);
+    } else {
+      auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
+      MS_EXCEPTION_IF_NULL(dim_data64);
+      dim_v = static_cast<int64_t>(*dim_data64);
+    }
+  } else {
+    if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
+      dim_v = GetValue<int64_t>(dim_value_ptr);
+    } else {
+      MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
+    }
+  }
+
+  return dim_v;
+}
+
+bool IsShapeInValid(const ShapeVector &shape) {
+  return std::any_of(shape.cbegin(), shape.cend(), [](int64_t s) { return s < 0; });
+}
+
+void CheckGatherShapeEqual(const std::string &prim_name, const ShapeVector &x_shape, int64_t dim_v,
+                           const ShapeVector &index_shape) {
+  if (IsShapeInValid(x_shape) || IsShapeInValid(index_shape)) {
+    return;
+  }
+
+  for (size_t i = 0; i < x_shape.size(); ++i) {
+    if (SizeToLong(i) == dim_v) continue;
+    MS_LOG(INFO) << "For '" << prim_name << "', it's now checking " << i << "th x shape.";
+    CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, index_shape[i], prim_name);
+  }
+}
+
 abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  // check
-  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
-  auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+  const size_t gather_d_input_num = 3;
+  MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == gather_d_input_num,
+                             "GatherD's input size should be 3 but got " + std::to_string(input_args.size()));
+
+  MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex0]->BuildShape()->isa<abstract::Shape>(), "x's shape wrong.");
+  auto shape_element = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
+  auto x_shape = shape_element->shape();
+  auto x_min_shape = shape_element->min_shape();
+  auto x_max_shape = shape_element->max_shape();
+  MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex2]->BuildShape()->isa<abstract::Shape>(), "index's shape wrong.");
+  auto index_shape_element = input_args[kInputIndex2]->BuildShape()->cast<abstract::ShapePtr>();
+  auto index_shape = index_shape_element->shape();
+  auto index_min_shape = index_shape_element->min_shape();
+  auto index_max_shape = index_shape_element->max_shape();
   int64_t x_rank = SizeToLong(x_shape.size());
   CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, SizeToLong(index_shape.size()), prim_name);
-  auto value_ptr = input_args[1]->BuildValue();
-  MS_EXCEPTION_IF_NULL(value_ptr);
-  auto dim_v = GetValue<int64_t>(value_ptr);
+  auto dim_v = GetGatherDimValue(input_args[kInputIndex1]);
   CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, -x_rank, prim_name);
   CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, x_rank, prim_name);
 
   if (dim_v < 0) {
     dim_v = dim_v + x_rank;
   }
-  for (size_t i = 0; i < x_shape.size(); ++i) {
-    if (SizeToLong(i) == dim_v) continue;
-    MS_LOG(INFO) << "For '" << prim_name << "', it's now checking " << i << "th x shape.";
-    CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, index_shape[i], prim_name);
-  }
-  return std::make_shared<abstract::Shape>(index_shape);
+  // For Ascend, only support x.shape[d] == index.shape[d] when d != dim. So limit it.
+  CheckGatherShapeEqual(prim_name, x_shape, dim_v, index_shape);
+  CheckGatherShapeEqual(prim_name, x_min_shape, dim_v, index_min_shape);
+  CheckGatherShapeEqual(prim_name, x_max_shape, dim_v, index_max_shape);
+  return std::make_shared<abstract::Shape>(index_shape, index_min_shape, index_max_shape);
 }
 
 TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(prim);
   auto prim_name = prim->name();
-  // check
   std::set<TypePtr> valid_x_type = {kTensorType};
   auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_x_type, prim_name);
   return x_type;
@@ -70,10 +132,12 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
   }
   auto prim_name = primitive->name();
   // check
-  std::set<TypePtr> valid_types = {kInt32, kInt64};
-  (void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), valid_types,
+  std::set<TypePtr> index_valid_types = {kInt32, kInt64};
+  std::set<TypePtr> dim_valid_types = {kInt32, kInt64, std::make_shared<TensorType>(kInt32),
+                                       std::make_shared<TensorType>(kInt64)};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), index_valid_types,
                                                    prim_name);
-  (void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), dim_valid_types, prim_name);
   return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
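
`CheckGatherShapeEqual` above only enforces per-axis equality when both shapes are fully known; any negative (dynamic) dimension makes it skip the check and defer to the kernel's Resize. A Python mirror of that rule (`check_gather_d_shapes` is a hypothetical name, not MindSpore API):

def check_gather_d_shapes(x_shape, dim, index_shape):
    # mirrors IsShapeInValid + CheckGatherShapeEqual from the hunk above
    if any(s < 0 for s in x_shape) or any(s < 0 for s in index_shape):
        return  # dynamic dims present: defer the check to runtime
    for i, (xs, ids) in enumerate(zip(x_shape, index_shape)):
        if i == dim:
            continue
        if xs != ids:
            raise ValueError(f"axis {i}: x has {xs}, index has {ids}")

check_gather_d_shapes([2, 3, 4], 1, [2, 5, 4])   # ok: axis 1 is the gather dim
check_gather_d_shapes([2, -1, 4], 1, [2, 5, 4])  # ok: dynamic, deferred
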
@@ -21,7 +21,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import constexpr
 from ..primitive import Primitive
 from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \
-    _handle_broadcasting, get_unsupported_dynamic_vmap_rule
+    _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis
 from ..operations.array_ops import Fills
 from ..operations.array_ops import UniqueConsecutive

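`_broadcast_by_axis` is the only new import here; judging from how the rule below uses it, it lifts an unbatched operand into the batch by giving it an axis of size `axis_size` at the batch position. A rough NumPy reference of the presumed behavior (our sketch, not the MindSpore implementation):

import numpy as np

def broadcast_by_axis_ref(x, axis, axis_size):
    # Insert a length-1 axis at `axis`, then tile it to `axis_size`.
    x = np.expand_dims(x, axis)
    reps = [1] * x.ndim
    reps[axis] = axis_size
    return np.tile(x, reps)

assert broadcast_by_axis_ref(np.ones((2, 3)), 0, 4).shape == (4, 2, 3)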
@@ -541,6 +541,48 @@ def get_ger_vmap_rule(prim, axis_size):
     return vmap_rule


+@vmap_rules_getters.register(P.GatherD)
+def get_gatherd_vmap_rule(prim, axis_size):
+    """VmapRule for GatherD operations."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    def vmap_rule(x_bdim, dim_bdim, index_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        dim_value, axis_dim = dim_bdim
+        index, index_bdim = index_bdim
+
+        # `dim` is a Tensor in the dynamic shape case; batching over it is not supported.
+        if axis_dim is not None:
+            _raise_value_error("The source axis of `dim` in `GatherD` must be None, "
+                               "but got {}.".format(axis_dim))
+        if not isinstance(dim_value, int):
+            _raise_value_error("The `dim` in `GatherD` must be a const, but got {}.".format(dim_value))
+
+        out_dim = index_dim
+
+        # Broadcast if needed, so that x and index share the same batch axis.
+        if x_dim is None:
+            x = _broadcast_by_axis(x, index_dim, axis_size)
+        elif index_dim is None:
+            index = _broadcast_by_axis(index, x_dim, axis_size)
+            out_dim = x_dim
+        elif x_dim != index_dim:
+            x = mnp.moveaxis(x, x_dim, index_dim)
+
+        # Adapt `dim` to the vmap case: axes at or after the batch axis shift right by one.
+        dim_value = dim_value + 1 if dim_value >= out_dim else dim_value
+
+        out = prim(x, dim_value, index)
+        return (out, out_dim)
+
+    return vmap_rule
+
+
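The `dim_value` adjustment is the crux of the rule: once a batch axis sits at position `out_dim`, every original axis at or after it moves one place to the right, so the gather axis must move with it. A standalone NumPy illustration of that equivalence (our own example, independent of MindSpore):

import numpy as np

# Per-example shape (3, 4); the batch axis is at position 0, so the original
# gather axis 1 becomes axis 2 in the batched tensors.
x = np.arange(24).reshape(2, 3, 4)
idx = (np.arange(24) % 4).reshape(2, 3, 4)
per_example = np.stack([np.take_along_axis(x[i], idx[i], axis=1) for i in range(2)])
batched = np.take_along_axis(x, idx, axis=1 + 1)  # dim 1 >= out_dim 0, so shift by one
assert np.array_equal(per_example, batched)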
 @vmap_rules_getters.register(P.SpaceToBatchND)
 def get_space_to_batch_nd_vmap_rule(prim, axis_size):
     """VmapRule for `SpaceToBatchND`."""

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,14 +13,20 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest

 import mindspore
 import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops.functional import vmap
+
+
+context.set_context(device_target="Ascend")

-
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 class Net(nn.Cell):
     def __init__(self, dim=0):
         super(Net, self).__init__()
@@ -40,26 +46,89 @@ class NetGrad(nn.Cell):
         return self.op(index, x)


-def test_net():
+def get_data(ms_type):
     x = Tensor(np.array([[772, 231, 508, 545, 615, 249],
                          [923, 210, 480, 696, 482, 761],
                          [465, 904, 521, 824, 607, 669],
                          [156, 539, 56, 159, 916, 566],
-                         [122, 676, 714, 261, 19, 936]]), mindspore.int32)
-    index = Tensor(np.array([[0, 0, 0, 1, 1],
-                             [0, 0, 0, 1, 4],
-                             [0, 0, 0, 1, -1],
-                             [1, 1, 1, 0, 0]]), mindspore.int32)
+                         [122, 676, 714, 261, 19, 936]]), ms_type)
     dim = 0
+    index = Tensor(np.array([[0, 1, 0, 1, 0, -4],
+                             [0, 2, 0, 2, 0, -3],
+                             [0, 0, 0, 3, 3, -2],
+                             [4, 4, 4, 0, 0, -1],
+                             [4, 3, 2, 1, -1, -2]]), mindspore.int32)
+    expect = np.array([[772, 210, 508, 696, 615, 761],
+                       [772, 904, 508, 824, 615, 669],
+                       [772, 231, 508, 159, 916, 566],
+                       [122, 676, 714, 545, 615, 936],
+                       [122, 539, 521, 696, 19, 566]])
+    res = (x, dim, index, expect)
+    return res
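For `dim = 0` the operator computes `out[i][j] = x[index[i][j]][j]`, with negative indices wrapping as in Python; that is how `expect` above is derived (for example, `expect[0][1] = x[index[0][1]][1] = x[1][1] = 210`, and the `-4` in row 0 wraps to row 1, picking 761). NumPy's `take_along_axis` reproduces it, as a quick cross-check outside the test:

import numpy as np

x = np.array([[772, 231, 508, 545, 615, 249],
              [923, 210, 480, 696, 482, 761],
              [465, 904, 521, 824, 607, 669],
              [156, 539, 56, 159, 916, 566],
              [122, 676, 714, 261, 19, 936]])
index = np.array([[0, 1, 0, 1, 0, -4],
                  [0, 2, 0, 2, 0, -3],
                  [0, 0, 0, 3, 3, -2],
                  [4, 4, 4, 0, 0, -1],
                  [4, 3, 2, 1, -1, -2]])
# out[i, j] = x[index[i, j], j]; negative indices wrap like normal indexing.
print(np.take_along_axis(x, index, axis=0))  # matches `expect` in get_data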
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
+def test_net(ms_type):
+    """
+    Feature: test GatherD static shape.
+    Description: inputs x and index have static shapes.
+    Expectation: the result matches the numpy result.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    x, dim, index, expect = get_data(ms_type)
     net = Net(dim)
     out = net(x, index)
-    print(out.asnumpy())
-
-    expect_out = np.array([[772, 231, 508, 696, 482],
-                           [772, 231, 508, 696, 19],
-                           [772, 231, 508, 696, 19],
-                           [923, 210, 480, 545, 615]])
-    assert np.array_equal(out.asnumpy(), expect_out)
+    assert np.array_equal(out.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
+def test_gatherd_dynamic(ms_type):
+    """
+    Feature: test GatherD dynamic shape.
+    Description: index has a dynamic shape.
+    Expectation: the result matches the numpy result.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    x, dim, index, expect = get_data(ms_type)
+    index_dyn = Tensor(shape=[index.shape[0], None], dtype=mindspore.int32)
+    net = Net(dim)
+    net.set_inputs(x, index_dyn)
+    out = net(x, index)
+
+    assert np.array_equal(out.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype', [np.int32, np.uint32, np.float32])
+def test_gatherd_vmap(dtype):
+    """
+    Feature: test GatherD vmap interface.
+    Description: inputs x and index have static shapes.
+    Expectation: the result matches the numpy result.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    def cal_gatherd(x, dim, index):
+        return P.GatherD()(x, dim, index)
+
+    gather_dim = 1
+    x = Tensor(np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]).astype(dtype))
+    y = Tensor(np.array([[[0, 0], [1, 0]], [[0, 0], [1, 0]], [[0, 0], [1, 0]]]).astype(np.int32))
+    outputs = vmap(cal_gatherd, in_axes=(0, None, 0), out_axes=0)(x, gather_dim, y)
+    expect = np.array([[[1, 1], [4, 3]], [[1, 1], [4, 3]], [[1, 1], [4, 3]]]).astype(dtype)
+    assert np.allclose(outputs.asnumpy(), expect)


 def test_net_bool():