add dynamic shape and vmap support for GatherD
This commit is contained in:
parent 515e54d361
commit 83280c2230
@@ -82,6 +82,7 @@ bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
  }
  return true;
}

bool ArgmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMax>(base_operator);
@@ -66,11 +66,32 @@ void CopyTask(size_t cur, std::vector<size_t> *pos, T *input, const I *index, co
  }
}  // namespace

void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  index_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bool GatherDCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  if (inputs.size() != 3) {
    MS_LOG(ERROR) << "GatherD input size must be equal to 3!";
    return false;
  }
  if (auto ret = MatchKernelFunc(base_operator, inputs, outputs); !ret) {
    return ret;
  }
  return true;
}

int GatherDCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs,
                                const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
    return ret;
  }

  const size_t kIndexIdx = 2;
  auto input_shape = inputs[0]->GetShapeVector();
  auto index_shape = inputs[kIndexIdx]->GetShapeVector();
  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
  (void)std::transform(index_shape.begin(), index_shape.end(), std::back_inserter(index_shape_), LongToSize);

  if (input_shape_.size() != index_shape_.size()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', shape size of 'x' must be equal to 'index', but got shape size of 'x': "
@@ -78,20 +99,11 @@ void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  }
  output_shape_ = index_shape_;

  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, GatherDFunc> &pair) { return pair.first; });
  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
  if (!is_match) {
    MS_LOG(EXCEPTION) << "GatherD does not support this kernel data type: " << kernel_attr;
  }

  kernel_func_ = func_list_[index].second;
  return KRET_OK;
}

template <typename T, typename I>
bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                       const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherDInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherDOutputsNum, kernel_name_);
@@ -157,67 +169,71 @@ bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
  return true;
}

std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::GatherDFunc>> GatherDCpuKernelMod::func_list_ = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat32),
   &GatherDCpuKernelMod::LaunchKernel<float, int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat32),
   &GatherDCpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat16),
   &GatherDCpuKernelMod::LaunchKernel<float16, int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   &GatherDCpuKernelMod::LaunchKernel<float16, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   &GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt32),
   &GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt64),
   &GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   &GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeBool)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeBool),
   &GatherDCpuKernelMod::LaunchKernel<bool, int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeBool)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeBool),
   &GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}};
const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> &GatherDCpuKernelMod::GetFuncList() const {
  static const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> func_list = {
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeFloat32),
     &GatherDCpuKernelMod::LaunchKernel<float, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeFloat32),
     &GatherDCpuKernelMod::LaunchKernel<float, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeFloat16),
     &GatherDCpuKernelMod::LaunchKernel<float16, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeFloat16),
     &GatherDCpuKernelMod::LaunchKernel<float16, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt32),
     &GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeInt32),
     &GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt64),
     &GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeInt64),
     &GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeBool)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeBool),
     &GatherDCpuKernelMod::LaunchKernel<bool, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeBool)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeBool),
     &GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}};

  return func_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GatherD, GatherDCpuKernelMod);
} // namespace kernel
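For reference, GatherD gathers along a single axis: each output element copies the input element whose coordinate on `dim` is taken from `index` at the same position. A minimal NumPy sketch of the semantics the CPU kernel implements (the `gather_d` helper is illustrative, not part of this change):

```python
import numpy as np

def gather_d(x, dim, index):
    """out[pos] = x[pos with pos[dim] replaced by index[pos]] -- GatherD semantics."""
    out = np.empty(index.shape, dtype=x.dtype)
    for pos in np.ndindex(*index.shape):
        src = list(pos)
        src[dim] = index[pos]          # only the gathered axis is redirected
        out[pos] = x[tuple(src)]
    return out

x = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [1, 0]])
print(gather_d(x, 1, index))           # [[1 1] [4 3]]
# NumPy's built-in equivalent agrees:
assert (gather_d(x, 1, index) == np.take_along_axis(x, index, axis=1)).all()
```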
@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -23,25 +24,27 @@

namespace mindspore {
namespace kernel {
class GatherDCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class GatherDCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GatherDCpuKernelMod> {
 public:
  GatherDCpuKernelMod() = default;
  ~GatherDCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
    return kernel_func_(this, inputs, workspace, outputs);
  }
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 private:
  template <typename T, typename I>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  using GatherDFunc = std::function<bool(GatherDCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                         const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, GatherDFunc>> func_list_;
  GatherDFunc kernel_func_;
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  std::vector<size_t> input_shape_;
  std::vector<size_t> index_shape_;
  std::vector<size_t> output_shape_;
@@ -15,92 +15,341 @@
 */

#include "plugin/device/gpu/kernel/arrays/gather_gpu_kernel.h"
#include <string>
#include <algorithm>

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  GatherD,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
  GatherFwdGpuKernelMod, double, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
  GatherFwdGpuKernelMod, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
  GatherFwdGpuKernelMod, float, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
  GatherFwdGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
  GatherFwdGpuKernelMod, half, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
  GatherFwdGpuKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  GatherFwdGpuKernelMod, int, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
  GatherFwdGpuKernelMod, int, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
  GatherFwdGpuKernelMod, int8_t, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
  GatherFwdGpuKernelMod, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
  GatherFwdGpuKernelMod, int16_t, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
  GatherFwdGpuKernelMod, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
  GatherFwdGpuKernelMod, int64_t, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  GatherFwdGpuKernelMod, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  GatherFwdGpuKernelMod, uint, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  GatherFwdGpuKernelMod, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
  GatherFwdGpuKernelMod, uchar, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
  GatherFwdGpuKernelMod, uchar, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
  GatherFwdGpuKernelMod, bool, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
  GatherFwdGpuKernelMod, bool, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  GatherFwdGpuKernelMod, uint32_t, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  GatherFwdGpuKernelMod, uint32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
  GatherFwdGpuKernelMod, uint64_t, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
  GatherFwdGpuKernelMod, uint64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
  GatherFwdGpuKernelMod, uint16_t, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
  GatherFwdGpuKernelMod, uint16_t, int64_t)
std::vector<std::pair<KernelAttr, GatherFwdGpuKernelMod::GatherFwdFunc>> GatherFwdGpuKernelMod::func_list_ = {
  // For static shape case:
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
  // For dynamic shape case:
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeBool)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeBool)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>}};

bool GatherFwdGpuKernelMod::SetDimParam(int64_t dim_value) {
  int64_t x_rank = SizeToLong(input_shapes_.size());
  if (dim_value < -x_rank || dim_value >= x_rank) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dim' must be in the range [-" << x_rank << "," << x_rank
                  << "), but got " << dim_value;
    return false;
  }
  if (dim_value < 0) {
    dim_value += x_rank;
  }

  size_t dim_before_axis = 1;
  for (size_t i = 0; i < LongToSize(dim_value); i++) {
    dim_before_axis *= output_shapes_[i];
  }
  size_t dim_at_axis_input = input_shapes_[LongToSize(dim_value)];
  size_t dim_at_axis_output = output_shapes_[LongToSize(dim_value)];
  size_t dim_after_axis = 1;
  for (size_t i = LongToSize(dim_value) + 1; i < output_shapes_.size(); i++) {
    dim_after_axis *= output_shapes_[i];
  }

  const size_t k2Idx = 2;
  const size_t k3Idx = 3;
  dims_[0] = dim_before_axis;
  dims_[1] = dim_at_axis_input;
  dims_[k2Idx] = dim_at_axis_output;
  dims_[k3Idx] = dim_after_axis;
  return true;
}
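SetDimParam collapses arbitrary-rank shapes into the four extents the CUDA kernel works with: the product of dimensions before the axis, the input and output extents at the axis, and the product of dimensions after it. A sketch of the equivalent flattened indexing, assuming a non-negative `dim` (the `gather_d_flat` helper is illustrative only):

```python
import numpy as np

def gather_d_flat(x, dim, index):
    # dims_[0], dims_[1], dims_[2], dims_[3] as computed in SetDimParam
    before = int(np.prod(index.shape[:dim], dtype=np.int64))
    at_in, at_out = x.shape[dim], index.shape[dim]
    after = int(np.prod(index.shape[dim + 1:], dtype=np.int64))
    x3 = x.reshape(before, at_in, after)
    idx3 = index.reshape(before, at_out, after)
    out = np.empty(idx3.shape, dtype=x.dtype)
    for i in range(before):            # roughly one CUDA thread per (i, j, k)
        for j in range(at_out):
            for k in range(after):
                out[i, j, k] = x3[i, idx3[i, j, k], k]
    return out.reshape(index.shape)

x = np.arange(24).reshape(2, 3, 4)
idx = np.zeros((2, 2, 4), dtype=np.int64)
assert (gather_d_flat(x, 1, idx) == np.take_along_axis(x, idx, axis=1)).all()
```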
bool GatherFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  size_t input_num = inputs.size();
  const size_t kStaticInputNum = 2;
  const size_t kDynInputNum = 3;
  if (input_num == kStaticInputNum) {
    is_dynamic_case_ = false;
  } else if (input_num == kDynInputNum) {
    is_dynamic_case_ = true;
    const size_t kDynIndexIdx = 2;
    index_idx_ = kDynIndexIdx;
  } else {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
    return false;
  }

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;

  return true;
}
int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
  if (ret != KRET_OK) {
    return ret;
  }

  auto input_shapes = inputs[0]->GetShapeVector();
  auto index_shapes = inputs[index_idx_]->GetShapeVector();
  auto output_shapes = outputs[0]->GetShapeVector();

  input_shapes_.clear();
  index_shapes_.clear();
  output_shapes_.clear();
  std::transform(input_shapes.cbegin(), input_shapes.cend(), std::back_inserter(input_shapes_), LongToSize);
  std::transform(index_shapes.cbegin(), index_shapes.cend(), std::back_inserter(index_shapes_), LongToSize);
  std::transform(output_shapes.cbegin(), output_shapes.cend(), std::back_inserter(output_shapes_), LongToSize);

  is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
                   CHECK_SHAPE_NULL(index_shapes_, kernel_name_, "input_indices") ||
                   CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");
  if (is_null_input_) {
    return KRET_OK;
  }

  if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input and output must be equal to the "
                  << "dimension of index: " << index_shapes_.size()
                  << ", but got the dimension of input: " << input_shapes_.size()
                  << ", the dimension of output: " << output_shapes_.size();
    return KRET_RESIZE_FAILED;
  }

  int64_t dim_value = 0;
  if (!is_dynamic_case_) {
    const std::string kAttrDim = "dim";
    auto dim_attr = base_operator->GetPrim()->GetAttr(kAttrDim);
    if (dim_attr == nullptr) {
      return KRET_RESIZE_FAILED;
    }
    dim_value = GetValue<int64_t>(dim_attr);
  } else {
    GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value);
  }
  if (!SetDimParam(dim_value)) {
    return KRET_RESIZE_FAILED;
  }

  return KRET_OK;
}
std::vector<KernelAttr> GatherFwdGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, GatherFwdFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GatherD, GatherFwdGpuKernelMod);
} // namespace kernel
} // namespace mindspore
@@ -14,124 +14,83 @@
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_

#include <vector>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class GatherFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class GatherFwdGpuKernelMod : public NativeGpuKernelMod {
 public:
  GatherFwdGpuKernelMod() : axis_(0), is_null_input_(false) {}
  GatherFwdGpuKernelMod() {}
  ~GatherFwdGpuKernelMod() = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (!kernel_func_) {
      MS_LOG(ERROR) << "GatherFwdGpu's kernel function is not initialized.";
      return false;
    }
    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T, typename S>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr) {
    if (is_null_input_) {
      return true;
    }
    VARIABLE_NOT_USED(workspace);
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    S *index_addr = GetDeviceAddress<S>(inputs, 1);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);

    Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3],
           reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID);
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    kernel_node_ = kernel_node;
    InitResource();
    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
    }
    input_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    index_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    output_shapes_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
    is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
                     CHECK_SHAPE_NULL(index_shapes_, kernel_name, "input_indices") ||
                     CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
    if (is_null_input_) {
      InitSizeLists();
      return true;
    }
    if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input and output must be equal to the "
                        << "dimension of index: " << index_shapes_.size()
                        << ", but got the dimension of input: " << input_shapes_.size()
                        << ", the dimension of output: " << output_shapes_.size();
    }
    int dims = SizeToInt(input_shapes_.size());
    axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "dim"));
    if (axis_ < -dims || axis_ >= dims) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
                        << "), but got " << axis_;
    }
    if (axis_ < 0) {
      axis_ += dims;
    }
    Reshape();
    InitSizeLists();
    auto input_addr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
    auto index_addr = reinterpret_cast<S *>(inputs.at(index_idx_)->addr);
    auto output_addr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);

    const size_t k2Idx = 2;
    const size_t k3Idx = 3;
    Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[k2Idx], dims_[k3Idx], cuda_stream,
           GET_CTX_DEVICE_ID);
    return true;
  }

 protected:
  void InitResource() override {}
  void InitSizeLists() override {
    size_t size = GetSize(input_shapes_, true);
    input_size_list_.push_back(size);
  bool SetDimParam(int64_t dim_value);

    size = GetSize(index_shapes_, false);
    input_size_list_.push_back(size);
  using GatherFwdFunc =
    std::function<bool(GatherFwdGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;

    size = GetSize(output_shapes_, true);
    output_size_list_.push_back(size);
  }

 private:
  void Reshape() {
    size_t dim_before_axis = 1;
    for (size_t i = 0; i < IntToSize(axis_); i++) {
      dim_before_axis *= output_shapes_[i];
    }
    size_t dim_at_axis_input = input_shapes_[IntToSize(axis_)];
    size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
    size_t dim_after_axis = 1;
    for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
      dim_after_axis *= output_shapes_[i];
    }

    dims_[0] = dim_before_axis;
    dims_[1] = dim_at_axis_input;
    dims_[2] = dim_at_axis_output;
    dims_[3] = dim_after_axis;
    return;
  }
  size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
    size_t result = flag ? sizeof(T) : sizeof(S);
    for (size_t i = 0; i < shape.size(); i++) {
      result *= shape[i];
    }
    return result;
  }
  GatherFwdFunc kernel_func_;
  static std::vector<std::pair<KernelAttr, GatherFwdFunc>> func_list_;

  std::vector<size_t> input_shapes_;
  std::vector<size_t> index_shapes_;
  std::vector<size_t> output_shapes_;

  size_t dims_[4] = {};
  int axis_;
  bool is_null_input_;
  bool is_null_input_{false};
  bool is_dynamic_case_{false};
  size_t index_idx_{1};
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
@@ -385,10 +385,15 @@ std::optional<std::vector<int64_t>> GetDynamicAttrIntValue(
                      << "th input in the depend_tensor_map";
  }
  auto input_tensor = depend_iter->second;
  const auto &input_shape = inputs[input_index]->GetShapeVector();
  auto input_shape = inputs[input_index]->GetShapeVector();
  // The shape kept in depend_tensor_map was already processed if it was empty.
  if (input_shape.empty()) {
    input_shape.push_back(1);
  }
  if (input_shape != input_tensor->shape()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the " << input_index
                      << "th input is different between the InferShape and the TensorShape";
                      << "th input is different between the InferShape and the TensorShape: " << input_shape << " vs "
                      << input_tensor->shape();
  }
  const auto &data_format = inputs[input_index]->GetFormat();
  if (data_format != mindspore::Format::DEFAULT_FORMAT && data_format != mindspore::Format::NCHW) {
@@ -69,6 +69,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
  static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
  static const auto &kGather = prim::kPrimGather->name();
  static const auto &kGatherV2 = prim::kPrimGatherV2->name();
  static const auto &kGatherD = prim::kPrimGatherD->name();
  static const auto &kSparseGatherV2 = prim::kPrimSparseGatherV2->name();
  static const auto &kRange = prim::kPrimRange->name();
  static const auto &kRangeV2 = prim::kPrimRangeV2->name();
@@ -101,6 +102,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
    {kMatrixSetDiagV3, ShapeSet{2}},
    {kGather, ShapeSet{2}},
    {kGatherV2, ShapeSet{2}},
    {kGatherD, ShapeSet{1}},
    {kSparseGatherV2, ShapeSet{2}},
    {kRange, ShapeSet{0, 1, 2}},
    {kRangeV2, ShapeSet{0, 1, 2}},
@@ -1,5 +1,5 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -26,35 +26,97 @@ namespace mindspore {
namespace ops {
// gather_d
namespace {
int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
  MS_EXCEPTION_IF_NULL(dim_ptr);
  auto dim_value_ptr = dim_ptr->BuildValue();
  MS_EXCEPTION_IF_NULL(dim_value_ptr);
  auto dim_type_ptr = dim_ptr->BuildType();
  MS_EXCEPTION_IF_NULL(dim_type_ptr);
  int64_t dim_v = 0;
  if (dim_value_ptr->isa<tensor::Tensor>()) {
    auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(dim_tensor);
    size_t data_size = dim_tensor->DataSize();
    MS_EXCEPTION_IF_CHECK_FAIL(data_size == 1, "dim value is not equal to one!");
    auto dim_type_id = dim_type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(dim_type_id);
    auto element = dim_type_id->element();
    MS_EXCEPTION_IF_NULL(element);
    if (element->type_id() == kNumberTypeInt32) {
      auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
      MS_EXCEPTION_IF_NULL(dim_data32);
      dim_v = static_cast<int64_t>(*dim_data32);
    } else {
      auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
      MS_EXCEPTION_IF_NULL(dim_data64);
      dim_v = static_cast<int64_t>(*dim_data64);
    }
  } else {
    if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
      dim_v = GetValue<int64_t>(dim_value_ptr);
    } else {
      MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
    }
  }

  return dim_v;
}

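GetGatherDimValue accepts `dim` either as a scalar value or, in the dynamic shape case, as a one-element int32/int64 tensor. A rough Python analogue of the extraction logic (helper name and shapes are illustrative, not the framework API):

```python
import numpy as np

def get_gather_dim_value(dim):
    if isinstance(dim, np.ndarray):          # tensor case
        if dim.size != 1:
            raise ValueError("dim value is not equal to one!")
        return int(dim.reshape(-1)[0])       # int32 and int64 both widen to int64
    if isinstance(dim, (int, np.integer)):   # scalar (Int32Imm/Int64Imm) case
        return int(dim)
    raise TypeError("For GatherD, 'dim' must be one of these types: [int32/int64].")
```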
bool IsShapeInValid(const ShapeVector &shape) {
  return std::any_of(shape.cbegin(), shape.cend(), [](int64_t s) { return s < 0; });
}

void CheckGatherShapeEqual(const std::string &prim_name, const ShapeVector &x_shape, int64_t dim_v,
                           const ShapeVector &index_shape) {
  if (IsShapeInValid(x_shape) || IsShapeInValid(index_shape)) {
    return;
  }

  for (size_t i = 0; i < x_shape.size(); ++i) {
    if (SizeToLong(i) == dim_v) continue;
    MS_LOG(INFO) << "For '" << prim_name << "', it's now checking " << i << "th x shape.";
    CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, index_shape[i], prim_name);
  }
}

abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  // check
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];

  const size_t gather_d_input_num = 3;
  MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == gather_d_input_num,
                             "GatherD's input size should be 3 but got " + std::to_string(input_args.size()));

  MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex0]->BuildShape()->isa<abstract::Shape>(), "x's shape wrong.");
  auto shape_element = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
  auto x_shape = shape_element->shape();
  auto x_min_shape = shape_element->min_shape();
  auto x_max_shape = shape_element->max_shape();
  MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex2]->BuildShape()->isa<abstract::Shape>(), "index's shape wrong.");
  auto index_shape_element = input_args[kInputIndex2]->BuildShape()->cast<abstract::ShapePtr>();
  auto index_shape = index_shape_element->shape();
  auto index_min_shape = index_shape_element->min_shape();
  auto index_max_shape = index_shape_element->max_shape();
  int64_t x_rank = SizeToLong(x_shape.size());
  CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, SizeToLong(index_shape.size()), prim_name);
  auto value_ptr = input_args[1]->BuildValue();
  MS_EXCEPTION_IF_NULL(value_ptr);
  auto dim_v = GetValue<int64_t>(value_ptr);
  auto dim_v = GetGatherDimValue(input_args[kInputIndex1]);
  CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, -x_rank, prim_name);
  CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, x_rank, prim_name);

  if (dim_v < 0) {
    dim_v = dim_v + x_rank;
  }
  for (size_t i = 0; i < x_shape.size(); ++i) {
    if (SizeToLong(i) == dim_v) continue;
    MS_LOG(INFO) << "For '" << prim_name << "', it's now checking " << i << "th x shape.";
    CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, index_shape[i], prim_name);
  }
  return std::make_shared<abstract::Shape>(index_shape);

  // For Ascend, only support x.shape[d] == index.shape[d] when d != dim. So limit it.
  CheckGatherShapeEqual(prim_name, x_shape, dim_v, index_shape);
  CheckGatherShapeEqual(prim_name, x_min_shape, dim_v, index_min_shape);
  CheckGatherShapeEqual(prim_name, x_max_shape, dim_v, index_max_shape);
  return std::make_shared<abstract::Shape>(index_shape, index_min_shape, index_max_shape);
}

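In short, the inferred output shape (and its min/max shapes in the dynamic case) is exactly the index shape, and every axis other than `dim` must agree between x and index. A quick NumPy check of that rule, assuming `np.take_along_axis` as the GatherD equivalent:

```python
import numpy as np

x = np.zeros((5, 6))
index = np.zeros((4, 6), dtype=np.int64)    # dim = 0: only axis 0 may differ
out = np.take_along_axis(x, index, axis=0)  # NumPy's GatherD equivalent
assert out.shape == index.shape             # output shape is exactly the index shape
# an index of shape (4, 7) would be rejected: axis 1 must equal x.shape[1] == 6
```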
TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // check
  std::set<TypePtr> valid_x_type = {kTensorType};
  auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_x_type, prim_name);
  return x_type;
@@ -70,10 +132,12 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
  }
  auto prim_name = primitive->name();
  // check
  std::set<TypePtr> valid_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), valid_types,
  std::set<TypePtr> index_valid_types = {kInt32, kInt64};
  std::set<TypePtr> dim_valid_types = {kInt32, kInt64, std::make_shared<TensorType>(kInt32),
                                       std::make_shared<TensorType>(kInt64)};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), index_valid_types,
                                                   prim_name);
  (void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), dim_valid_types, prim_name);
  return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
@@ -21,7 +21,7 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \
    _handle_broadcasting, get_unsupported_dynamic_vmap_rule
    _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis
from ..operations.array_ops import Fills
from ..operations.array_ops import UniqueConsecutive
@@ -541,6 +541,48 @@ def get_ger_vmap_rule(prim, axis_size):
    return vmap_rule


@vmap_rules_getters.register(P.GatherD)
def get_gatherd_vmap_rule(prim, axis_size):
    """VmapRule for GatherD operations."""
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim, dim_bdim, index_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        dim_value, axis_dim = dim_bdim
        index, index_dim = index_bdim

        # `dim` will be a Tensor in the dynamic shape case; its vmap is not supported.
        if axis_dim is not None:
            _raise_value_error("The source axis of `dim` in `GatherD` must be None, "
                               "but got {}.".format(axis_dim))
        if not isinstance(dim_value, int):
            _raise_value_error("The `dim` in `GatherD` must be a const, but got {}.".format(dim_value))

        out_dim = index_dim

        # Broadcast if needed.
        if x_dim is None:
            x = _broadcast_by_axis(x, index_dim, axis_size)
        elif index_dim is None:
            index = _broadcast_by_axis(index, x_dim, axis_size)
            out_dim = x_dim
        elif x_dim != index_dim:
            x = mnp.moveaxis(x, x_dim, index_dim)

        # Adapt `dim` to the vmap case.
        dim_value = dim_value + 1 if dim_value >= out_dim else dim_value

        out = prim(x, dim_value, index)
        return (out, out_dim)

    return vmap_rule


@vmap_rules_getters.register(P.SpaceToBatchND)
def get_space_to_batch_nd_vmap_rule(prim, axis_size):
    """VmapRule for `SpaceToBatchND`."""
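The `dim_value + 1` adjustment above exists because vmap moves the batch axis to the front (position `out_dim`), shifting every original axis at or after it right by one. A NumPy check of that equivalence, using `np.take_along_axis` as the GatherD stand-in:

```python
import numpy as np

batch_x = np.stack([np.array([[1, 2], [3, 4]])] * 3)    # batch axis at position 0
batch_idx = np.stack([np.array([[0, 0], [1, 0]])] * 3)
dim = 1                                                  # per-example gather axis
batched = np.take_along_axis(batch_x, batch_idx, axis=dim + 1)   # dim shifted past the batch axis
manual = np.stack([np.take_along_axis(batch_x[b], batch_idx[b], axis=dim) for b in range(3)])
assert (batched == manual).all()
```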
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,14 +13,20 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.functional import vmap


context.set_context(device_target="Ascend")


context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
    def __init__(self, dim=0):
        super(Net, self).__init__()
@@ -40,26 +46,89 @@ class NetGrad(nn.Cell):
        return self.op(index, x)


def test_net():
def get_data(ms_type):
    x = Tensor(np.array([[772, 231, 508, 545, 615, 249],
                         [923, 210, 480, 696, 482, 761],
                         [465, 904, 521, 824, 607, 669],
                         [156, 539, 56, 159, 916, 566],
                         [122, 676, 714, 261, 19, 936]]), mindspore.int32)
    index = Tensor(np.array([[0, 0, 0, 1, 1],
                             [0, 0, 0, 1, 4],
                             [0, 0, 0, 1, -1],
                             [1, 1, 1, 0, 0]]), mindspore.int32)
                         [122, 676, 714, 261, 19, 936]]), ms_type)
    dim = 0
    index = Tensor(np.array([[0, 1, 0, 1, 0, -4],
                             [0, 2, 0, 2, 0, -3],
                             [0, 0, 0, 3, 3, -2],
                             [4, 4, 4, 0, 0, -1],
                             [4, 3, 2, 1, -1, -2]]), mindspore.int32)
    expect = np.array([[772, 210, 508, 696, 615, 761],
                       [772, 904, 508, 824, 615, 669],
                       [772, 231, 508, 159, 916, 566],
                       [122, 676, 714, 545, 615, 936],
                       [122, 539, 521, 696, 19, 566]])
    res = (x, dim, index, expect)
    return res


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
def test_net(ms_type):
    """
    Feature: test GatherD static shape.
    Description: input x and index are static shape.
    Expectation: the result matches the numpy result
    """
    context.set_context(mode=context.GRAPH_MODE)
    x, dim, index, expect = get_data(ms_type)
    net = Net(dim)
    out = net(x, index)
    print(out.asnumpy())

    expect_out = np.array([[772, 231, 508, 696, 482],
                           [772, 231, 508, 696, 19],
                           [772, 231, 508, 696, 19],
                           [923, 210, 480, 545, 615]])
    assert np.array_equal(out.asnumpy(), expect_out)
    assert np.array_equal(out.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
def test_gatherd_dynamic(ms_type):
    """
    Feature: test GatherD dynamic shape.
    Description: index is dynamic shape.
    Expectation: the result matches the numpy result
    """
    context.set_context(mode=context.GRAPH_MODE)
    x, dim, index, expect = get_data(ms_type)
    index_dyn = Tensor(shape=[index.shape[0], None], dtype=mindspore.int32)
    net = Net(dim)
    net.set_inputs(x, index_dyn)
    out = net(x, index)

    assert np.array_equal(out.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.int32, np.uint32, np.float32])
def test_gatherd_vmap(dtype):
    """
    Feature: test GatherD vmap interface.
    Description: input x and index are static shape.
    Expectation: the result matches the numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE)

    def cal_gatherd(x, dim, index):
        return P.GatherD()(x, dim, index)

    gather_dim = 1
    x = Tensor(np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]).astype(dtype))
    y = Tensor(np.array([[[0, 0], [1, 0]], [[0, 0], [1, 0]], [[0, 0], [1, 0]]]).astype(np.int32))
    outputs = vmap(cal_gatherd, in_axes=(0, None, 0), out_axes=0)(x, gather_dim, y)
    expect = np.array([[[1, 1], [4, 3]], [[1, 1], [4, 3]], [[1, 1], [4, 3]]]).astype(dtype)
    assert np.allclose(outputs.asnumpy(), expect)


def test_net_bool():