Add dynamic shape and vmap support for GatherD

This commit is contained in:
tronzhang 2022-05-26 14:35:12 +08:00
parent 515e54d361
commit 83280c2230
10 changed files with 702 additions and 292 deletions

View File

@ -82,6 +82,7 @@ bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
}
return true;
}
bool ArgmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMax>(base_operator);

View File

@ -66,11 +66,32 @@ void CopyTask(size_t cur, std::vector<size_t> *pos, T *input, const I *index, co
}
} // namespace
void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
index_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
// Initializes the GatherD CPU kernel: records the operator name, validates the
// expected input arity (x, dim, index), and resolves the dtype-matched launch
// function. Returns false when the arity or dtype combination is unsupported.
bool GatherDCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
constexpr size_t expected_input_num = 3;
if (inputs.size() != expected_input_num) {
MS_LOG(ERROR) << "GatherD input size must be equal to 3!";
return false;
}
// MatchKernelFunc binds the LaunchKernel instantiation matching the runtime
// dtypes; its boolean result is exactly this Init's success/failure.
return MatchKernelFunc(base_operator, inputs, outputs);
}
int GatherDCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
const size_t kIndexIdx = 2;
auto input_shape = inputs[0]->GetShapeVector();
auto index_shape = inputs[kIndexIdx]->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
(void)std::transform(index_shape.begin(), index_shape.end(), std::back_inserter(index_shape_), LongToSize);
if (input_shape_.size() != index_shape_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', shape size of 'x' must be equal to 'index', but got shape size of 'x': "
@ -78,20 +99,11 @@ void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
}
output_shape_ = index_shape_;
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, GatherDFunc> &pair) { return pair.first; });
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "GatherD does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return KRET_OK;
}
template <typename T, typename I>
bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherDInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherDOutputsNum, kernel_name_);
@ -157,67 +169,71 @@ bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
return true;
}
// Legacy static table mapping each supported (x dtype, dim dtype, index dtype)
// signature to the matching LaunchKernel instantiation. The second input
// ('dim') is always int32 here; 'index' may be int32 or int64; the output
// dtype mirrors 'x'. NOTE(review): this deprecated-style member appears to be
// superseded by GetFuncList() below — confirm it is removed together with the
// old InitKernel path.
std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::GatherDFunc>> GatherDCpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&GatherDCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&GatherDCpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&GatherDCpuKernelMod::LaunchKernel<float16, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&GatherDCpuKernelMod::LaunchKernel<float16, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
&GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool),
&GatherDCpuKernelMod::LaunchKernel<bool, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
&GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}};
// Returns the signature -> launch-function table used by MatchKernelFunc.
// Inputs are (x, dim, index): 'dim' is always int32 in this table, 'index'
// may be int32 or int64, and the output dtype mirrors 'x'. Entry order
// defines matching priority.
const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> &GatherDCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> func_list = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&GatherDCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&GatherDCpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&GatherDCpuKernelMod::LaunchKernel<float16, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&GatherDCpuKernelMod::LaunchKernel<float16, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
&GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool),
&GatherDCpuKernelMod::LaunchKernel<bool, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
&GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GatherD, GatherDCpuKernelMod);
} // namespace kernel

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
@ -23,25 +24,27 @@
namespace mindspore {
namespace kernel {
class GatherDCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class GatherDCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GatherDCpuKernelMod> {
public:
GatherDCpuKernelMod() = default;
~GatherDCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
return kernel_func_(this, inputs, workspace, outputs);
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
private:
template <typename T, typename I>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using GatherDFunc = std::function<bool(GatherDCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, GatherDFunc>> func_list_;
GatherDFunc kernel_func_;
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
std::vector<size_t> input_shape_;
std::vector<size_t> index_shape_;
std::vector<size_t> output_shape_;

View File

@ -15,92 +15,341 @@
*/
#include "plugin/device/gpu/kernel/arrays/gather_gpu_kernel.h"
#include <string>
#include <algorithm>
namespace mindspore {
namespace kernel {
// Legacy per-signature registrations for the static-shape GatherD GPU kernel
// with inputs (x, index); 'dim' was taken from the primitive attribute.
// NOTE(review): UInt32 is registered twice — once as 'uint' and again as
// 'uint32_t' near the end — the later pair is redundant; confirm before
// relying on registration order.
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherFwdGpuKernelMod, double, int)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
GatherFwdGpuKernelMod, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherFwdGpuKernelMod, float, int)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
GatherFwdGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherFwdGpuKernelMod, half, int)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
GatherFwdGpuKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherFwdGpuKernelMod, int, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
GatherFwdGpuKernelMod, int, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
GatherFwdGpuKernelMod, int8_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
GatherFwdGpuKernelMod, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherFwdGpuKernelMod, int16_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
GatherFwdGpuKernelMod, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
GatherFwdGpuKernelMod, int64_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
GatherFwdGpuKernelMod, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherFwdGpuKernelMod, uint, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherFwdGpuKernelMod, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherFwdGpuKernelMod, uchar, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
GatherFwdGpuKernelMod, uchar, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
GatherFwdGpuKernelMod, bool, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
GatherFwdGpuKernelMod, bool, int64_t)
// Redundant duplicate of the 'uint' registrations above (same KernelAttr).
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherFwdGpuKernelMod, uint32_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherFwdGpuKernelMod, uint32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
GatherFwdGpuKernelMod, uint64_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
GatherFwdGpuKernelMod, uint64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
GatherFwdGpuKernelMod, uint16_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
GatherFwdGpuKernelMod, uint16_t, int64_t)
// Signature -> launch-function table for GatherD on GPU. Two families:
//   * static shape:  inputs are (x, index); 'dim' comes from the primitive attr.
//   * dynamic shape: inputs are (x, dim:int64, index).
// MatchKernelAttr() in Init() returns the first matching entry, so entry order
// is the matching priority. The former duplicate UInt32 entries (registered
// once for 'uint' and again for 'uint32_t', which name the same type with an
// identical KernelAttr) were unreachable and have been removed.
std::vector<std::pair<KernelAttr, GatherFwdGpuKernelMod::GatherFwdFunc>> GatherFwdGpuKernelMod::func_list_ = {
  // For static shape case:
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
  // For dynamic shape case:
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat64),
   &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat32),
   &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   &GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt32),
   &GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt8),
   &GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeBool),
   &GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt16),
   &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>}};
// Validates and normalizes 'dim', then caches the four collapsed sizes the
// Gather launch uses: [before-axis, input-at-axis, output-at-axis, after-axis].
// Returns false (with an error log) when 'dim' is outside [-rank, rank).
bool GatherFwdGpuKernelMod::SetDimParam(int64_t dim_value) {
const int64_t x_rank = SizeToLong(input_shapes_.size());
const bool out_of_range = dim_value < -x_rank || dim_value >= x_rank;
if (out_of_range) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dim' must be in the range [-" << x_rank << "," << x_rank
<< "), but got " << dim_value;
return false;
}
// A negative dim counts from the back.
const size_t axis = LongToSize(dim_value < 0 ? dim_value + x_rank : dim_value);
size_t before = 1;
for (size_t i = 0; i < axis; ++i) {
before *= output_shapes_[i];
}
size_t after = 1;
for (size_t i = axis + 1; i < output_shapes_.size(); ++i) {
after *= output_shapes_[i];
}
const size_t k2Idx = 2;
const size_t k3Idx = 3;
dims_[0] = before;
dims_[1] = input_shapes_[axis];
dims_[k2Idx] = output_shapes_[axis];
dims_[k3Idx] = after;
return true;
}
// Determines the input arity (static: x, index; dynamic: x, dim, index) and
// binds the launch functor matching the runtime dtype signature.
// Returns false on an unexpected arity or an unsupported dtype combination.
bool GatherFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
const size_t kStaticInputNum = 2;
const size_t kDynInputNum = 3;
const size_t input_num = inputs.size();
if (input_num == kDynInputNum) {
// Dynamic case carries 'dim' as input 1, so the index tensor moves to slot 2.
is_dynamic_case_ = true;
const size_t kDynIndexIdx = 2;
index_idx_ = kDynIndexIdx;
} else if (input_num == kStaticInputNum) {
is_dynamic_case_ = false;
} else {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
return false;
}
const auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
const auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
if (ret != KRET_OK) {
return ret;
}
auto input_shapes = inputs[0]->GetShapeVector();
auto index_shapes = inputs[index_idx_]->GetShapeVector();
auto output_shapes = outputs[0]->GetShapeVector();
input_shapes_.clear();
index_shapes_.clear();
output_shapes_.clear();
std::transform(input_shapes.cbegin(), input_shapes.cend(), std::back_inserter(input_shapes_), LongToSize);
std::transform(index_shapes.cbegin(), index_shapes.cend(), std::back_inserter(index_shapes_), LongToSize);
std::transform(output_shapes.cbegin(), output_shapes.cend(), std::back_inserter(output_shapes_), LongToSize);
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
CHECK_SHAPE_NULL(index_shapes_, kernel_name_, "input_indices") ||
CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");
if (is_null_input_) {
return KRET_OK;
}
if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input and output must be the equal to the "
<< "dimension of index: " << index_shapes_.size()
<< ", but got the dimension of input: " << input_shapes_.size()
<< ", the dimension of output: " << output_shapes_.size();
return KRET_RESIZE_FAILED;
}
int64_t dim_value = 0;
if (!is_dynamic_case_) {
const std::string kAttrDim = "dim";
auto dim_attr = base_operator->GetPrim()->GetAttr(kAttrDim);
if (dim_attr == nullptr) {
return KRET_RESIZE_FAILED;
}
dim_value = GetValue<int64_t>(dim_attr);
} else {
GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value);
}
if (!SetDimParam(dim_value)) {
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
std::vector<KernelAttr> GatherFwdGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, GatherFwdFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GatherD, GatherFwdGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -14,124 +14,83 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class GatherFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class GatherFwdGpuKernelMod : public NativeGpuKernelMod {
public:
GatherFwdGpuKernelMod() : axis_(0), is_null_input_(false) {}
GatherFwdGpuKernelMod() {}
~GatherFwdGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (!kernel_func_) {
MS_LOG(ERROR) << "GatherFwdGpu's kernel function is not initialized.";
return false;
}
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *index_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3],
reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID);
return true;
}
// Reads static shape and axis information from the kernel node, validates it,
// and precomputes the flattened dimensions consumed by the CUDA Gather kernel.
// Returns true on success; raises on invalid input count, rank mismatch or
// out-of-range axis.
bool Init(const CNodePtr &kernel_node) override {
  auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
  kernel_node_ = kernel_node;
  InitResource();
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 2) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
  }
  input_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  index_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  output_shapes_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
  is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
                   CHECK_SHAPE_NULL(index_shapes_, kernel_name, "input_indices") ||
                   CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
  if (is_null_input_) {
    InitSizeLists();
    return true;
  }
  if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
    // Use the local kernel_name here: the member kernel_name_ may not be set for this mod.
    MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input and output must be the equal to the "
                      << "dimension of index: " << index_shapes_.size()
                      << ", but got the dimension of input: " << input_shapes_.size()
                      << ", the dimension of output: " << output_shapes_.size();
  }
  int dims = SizeToInt(input_shapes_.size());
  axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "dim"));
  if (axis_ < -dims || axis_ >= dims) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the 'axis' must be in the range [-" << dims << "," << dims
                      << "), but got " << axis_;
  }
  // Normalize a negative axis into [0, dims).
  if (axis_ < 0) {
    axis_ += dims;
  }
  Reshape();
  InitSizeLists();
  // NOTE(review): the previous body went on to dereference `inputs`/`outputs`/`stream_ptr`,
  // which are not parameters of Init; that launch code belongs in LaunchKernel and was removed.
  return true;
}
protected:
void InitResource() override {}
void InitSizeLists() override {
size_t size = GetSize(input_shapes_, true);
input_size_list_.push_back(size);
bool SetDimParam(int64_t dim_value);
size = GetSize(index_shapes_, false);
input_size_list_.push_back(size);
using GatherFwdFunc =
std::function<bool(GatherFwdGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
size = GetSize(output_shapes_, true);
output_size_list_.push_back(size);
}
private:
void Reshape() {
size_t dim_before_axis = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_at_axis_input = input_shapes_[IntToSize(axis_)];
size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
size_t dim_after_axis = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_axis *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_at_axis_input;
dims_[2] = dim_at_axis_output;
dims_[3] = dim_after_axis;
return;
}
// Byte size of a tensor with `shape`: element count times sizeof(T) when `flag`
// is true (data tensor) or sizeof(S) when false (index tensor).
// NOTE(review): T and S come from the enclosing template context — confirm this
// member is only instantiated there.
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
size_t result = flag ? sizeof(T) : sizeof(S);
for (size_t i = 0; i < shape.size(); i++) {
result *= shape[i];
}
return result;
}
// Dispatch entry selected by MatchKernelFunc for the runtime dtypes.
GatherFwdFunc kernel_func_;
static std::vector<std::pair<KernelAttr, GatherFwdFunc>> func_list_;
std::vector<size_t> input_shapes_;
std::vector<size_t> index_shapes_;
std::vector<size_t> output_shapes_;
// {before-axis, input-at-axis, output-at-axis, after-axis}, filled by Reshape().
size_t dims_[4] = {};
int axis_{0};
// NOTE(review): is_null_input_ was declared twice (once uninitialized) — a
// redeclaration error; only the initialized declaration is kept.
bool is_null_input_{false};
bool is_dynamic_case_{false};
size_t index_idx_{1};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_

View File

@ -385,10 +385,15 @@ std::optional<std::vector<int64_t>> GetDynamicAttrIntValue(
<< "th input in the depend_tensor_map";
}
auto input_tensor = depend_iter->second;
const auto &input_shape = inputs[input_index]->GetShapeVector();
auto input_shape = inputs[input_index]->GetShapeVector();
// The shape keep in depend_tensor_map was processed if it was empty.
if (input_shape.empty()) {
input_shape.push_back(1);
}
if (input_shape != input_tensor->shape()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the " << input_index
<< "th input is different between the InferShape and the TensorShape";
<< "th input is different between the InferShape and the TensorShape: " << input_shape << " vs "
<< input_tensor->shape();
}
const auto &data_format = inputs[input_index]->GetFormat();
if (data_format != mindspore::Format::DEFAULT_FORMAT && data_format != mindspore::Format::NCHW) {

View File

@ -69,6 +69,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
static const auto &kGather = prim::kPrimGather->name();
static const auto &kGatherV2 = prim::kPrimGatherV2->name();
static const auto &kGatherD = prim::kPrimGatherD->name();
static const auto &kSparseGatherV2 = prim::kPrimSparseGatherV2->name();
static const auto &kRange = prim::kPrimRange->name();
static const auto &kRangeV2 = prim::kPrimRangeV2->name();
@ -101,6 +102,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
{kMatrixSetDiagV3, ShapeSet{2}},
{kGather, ShapeSet{2}},
{kGatherV2, ShapeSet{2}},
{kGatherD, ShapeSet{1}},
{kSparseGatherV2, ShapeSet{2}},
{kRange, ShapeSet{0, 1, 2}},
{kRangeV2, ShapeSet{0, 1, 2}},

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,35 +26,97 @@ namespace mindspore {
namespace ops {
// gather_d
namespace {
int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
MS_EXCEPTION_IF_NULL(dim_ptr);
auto dim_value_ptr = dim_ptr->BuildValue();
MS_EXCEPTION_IF_NULL(dim_value_ptr);
auto dim_type_ptr = dim_ptr->BuildType();
MS_EXCEPTION_IF_NULL(dim_type_ptr);
int64_t dim_v = 0;
if (dim_value_ptr->isa<tensor::Tensor>()) {
auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(dim_tensor);
size_t data_size = dim_tensor->DataSize();
MS_EXCEPTION_IF_CHECK_FAIL(data_size == 1, "dim value is not equal to one!");
auto dim_type_id = dim_type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(dim_type_id);
auto element = dim_type_id->element();
MS_EXCEPTION_IF_NULL(element);
if (element->type_id() == kNumberTypeInt32) {
auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data32);
dim_v = static_cast<int64_t>(*dim_data32);
} else {
auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data64);
dim_v = static_cast<int64_t>(*dim_data64);
}
} else {
if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
dim_v = GetValue<int64_t>(dim_value_ptr);
} else {
MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
}
}
return dim_v;
}
bool IsShapeInValid(const ShapeVector &shape) {
return std::any_of(shape.cbegin(), shape.cend(), [](int64_t s) { return s < 0; });
}
// Verifies x and index have equal extents on every axis except dim_v.
// Skipped entirely when either shape contains a dynamic (negative) extent,
// since those dimensions cannot be compared at infer time.
void CheckGatherShapeEqual(const std::string &prim_name, const ShapeVector &x_shape, int64_t dim_v,
const ShapeVector &index_shape) {
if (IsShapeInValid(x_shape) || IsShapeInValid(index_shape)) {
return;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (SizeToLong(i) == dim_v) continue;
MS_LOG(INFO) << "For '" << prim_name << "', it's now checking " << i << "th x shape.";
CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, index_shape[i], prim_name);
}
}
// Infers GatherD's output shape: same as index's shape (including min/max
// shapes for the dynamic case). Validates input count, rank equality, the dim
// value range, and per-axis extent equality (except at dim).
// NOTE(review): the merged old/new hunk lines (duplicate x_shape/index_shape/
// dim_v declarations and a duplicate check-loop/return) were removed; the
// dynamic-shape-aware version is kept.
abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const size_t gather_d_input_num = 3;
  MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == gather_d_input_num,
                             "GatherD's input size should be 3 but got " + std::to_string(input_args.size()));
  MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex0]->BuildShape()->isa<abstract::Shape>(), "x's shape wrong.");
  auto shape_element = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
  auto x_shape = shape_element->shape();
  auto x_min_shape = shape_element->min_shape();
  auto x_max_shape = shape_element->max_shape();
  MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex2]->BuildShape()->isa<abstract::Shape>(), "index's shape wrong.");
  auto index_shape_element = input_args[kInputIndex2]->BuildShape()->cast<abstract::ShapePtr>();
  auto index_shape = index_shape_element->shape();
  auto index_min_shape = index_shape_element->min_shape();
  auto index_max_shape = index_shape_element->max_shape();
  int64_t x_rank = SizeToLong(x_shape.size());
  CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, SizeToLong(index_shape.size()), prim_name);
  auto dim_v = GetGatherDimValue(input_args[kInputIndex1]);
  CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, -x_rank, prim_name);
  CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, x_rank, prim_name);
  // Normalize a negative dim into [0, x_rank).
  if (dim_v < 0) {
    dim_v = dim_v + x_rank;
  }
  // For Ascend, only support x.shape[d] == index.shape[d] when d != dim. So limit it.
  CheckGatherShapeEqual(prim_name, x_shape, dim_v, index_shape);
  CheckGatherShapeEqual(prim_name, x_min_shape, dim_v, index_min_shape);
  CheckGatherShapeEqual(prim_name, x_max_shape, dim_v, index_max_shape);
  return std::make_shared<abstract::Shape>(index_shape, index_min_shape, index_max_shape);
}
TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
// check
std::set<TypePtr> valid_x_type = {kTensorType};
auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_x_type, prim_name);
return x_type;
@ -70,10 +132,12 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
}
auto prim_name = primitive->name();
// check
std::set<TypePtr> valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), valid_types,
std::set<TypePtr> index_valid_types = {kInt32, kInt64};
std::set<TypePtr> dim_valid_types = {kInt32, kInt64, std::make_shared<TensorType>(kInt32),
std::make_shared<TensorType>(kInt64)};
(void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), index_valid_types,
prim_name);
(void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), valid_types, prim_name);
(void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), dim_valid_types, prim_name);
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);

View File

@ -21,7 +21,7 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \
_handle_broadcasting, get_unsupported_dynamic_vmap_rule
_handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis
from ..operations.array_ops import Fills
from ..operations.array_ops import UniqueConsecutive
@ -541,6 +541,48 @@ def get_ger_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.GatherD)
def get_gatherd_vmap_rule(prim, axis_size):
    """VmapRule for GatherD operations."""
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(x_bdim, dim_bdim, index_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        dim_value, axis_dim = dim_bdim
        index, index_dim = index_bdim
        # `dim` will be a Tensor in dynamic shape case, do not support its vmap.
        if axis_dim is not None:
            _raise_value_error("The source axis of `dim` in `GatherD` must be None, "
                               "but got {}.".format(axis_dim))
        if not isinstance(dim_value, int):
            _raise_value_error("The `dim` in `GatherD` must be a const, but got {}.".format(dim_value))

        out_dim = index_dim
        # Broadcast if needed so x and index share the same batch axis.
        if x_dim is None:
            x = _broadcast_by_axis(x, index_dim, axis_size)
        elif index_dim is None:
            index = _broadcast_by_axis(index, x_dim, axis_size)
            out_dim = x_dim
        elif x_dim != index_dim:
            # Bug fix: moveaxis returns a new tensor; the result must be kept,
            # otherwise x's batch axis is never actually moved.
            x = mnp.moveaxis(x, x_dim, index_dim)

        # Adapt `dim` to vmap case: logical axes at/after the batch axis shift by one.
        dim_value = dim_value + 1 if dim_value >= out_dim else dim_value
        out = prim(x, dim_value, index)
        return (out, out_dim)

    return vmap_rule
@vmap_rules_getters.register(P.SpaceToBatchND)
def get_space_to_batch_nd_vmap_rule(prim, axis_size):
"""VmapRule for `SpaceToBatchND`."""

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,14 +13,20 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.functional import vmap
context.set_context(device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self, dim=0):
super(Net, self).__init__()
@ -40,26 +46,89 @@ class NetGrad(nn.Cell):
return self.op(index, x)
def get_data(ms_type):
    """Build GatherD test data: returns (x, dim, index, expect) for dim=0.

    Note: the stray leftover `def test_net():` line and the duplicated old
    index/x closings from the merged diff were removed; only the new data set
    is kept.
    """
    x = Tensor(np.array([[772, 231, 508, 545, 615, 249],
                         [923, 210, 480, 696, 482, 761],
                         [465, 904, 521, 824, 607, 669],
                         [156, 539, 56, 159, 916, 566],
                         [122, 676, 714, 261, 19, 936]]), ms_type)
    dim = 0
    # Negative indices wrap around along axis 0 (e.g. -4 -> row 1 of 5).
    index = Tensor(np.array([[0, 1, 0, 1, 0, -4],
                             [0, 2, 0, 2, 0, -3],
                             [0, 0, 0, 3, 3, -2],
                             [4, 4, 4, 0, 0, -1],
                             [4, 3, 2, 1, -1, -2]]), mindspore.int32)
    expect = np.array([[772, 210, 508, 696, 615, 761],
                       [772, 904, 508, 824, 615, 669],
                       [772, 231, 508, 159, 916, 566],
                       [122, 676, 714, 545, 615, 936],
                       [122, 539, 521, 696, 19, 566]])
    res = (x, dim, index, expect)
    return res
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
def test_net(ms_type):
    """
    Feature: test GatherD static shape.
    Description: input x and index is static shape.
    Expectation: the result match with numpy result
    """
    # The merged old-diff residue (print + stale expect_out/assert) was removed;
    # the expectation comes from get_data, matching the parametrized dtype.
    context.set_context(mode=context.GRAPH_MODE)
    x, dim, index, expect = get_data(ms_type)
    net = Net(dim)
    out = net(x, index)
    assert np.array_equal(out.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
def test_gatherd_dynamic(ms_type):
    """
    Feature: test GatherD dynamic shape.
    Description: index is dynamic shape.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE)
    x, dim, index, expect = get_data(ms_type)
    net = Net(dim)
    # Mark the second axis of index as dynamic before running.
    dyn_index = Tensor(shape=[index.shape[0], None], dtype=mindspore.int32)
    net.set_inputs(x, dyn_index)
    result = net(x, index)
    assert np.array_equal(result.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.int32, np.uint32, np.float32])
def test_gatherd_vmap(dtype):
    """
    Feature: test GatherD vmap interface.
    Description: input x and index is static shape.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE)

    def call_gatherd(inp, axis, idx):
        return P.GatherD()(inp, axis, idx)

    gather_dim = 1
    batched_x = Tensor(np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]).astype(dtype))
    batched_idx = Tensor(np.array([[[0, 0], [1, 0]], [[0, 0], [1, 0]], [[0, 0], [1, 0]]]).astype(np.int32))
    # Batch along axis 0 for x and index; dim stays a scalar constant.
    out = vmap(call_gatherd, in_axes=(0, None, 0), out_axes=0)(batched_x, gather_dim, batched_idx)
    expect = np.array([[[1, 1], [4, 3]], [[1, 1], [4, 3]], [[1, 1], [4, 3]]]).astype(dtype)
    assert np.allclose(out.asnumpy(), expect)
def test_net_bool():