add dynamic and vmap support for gatherd

This commit is contained in:
tronzhang 2022-05-26 14:35:12 +08:00
parent 515e54d361
commit 83280c2230
10 changed files with 702 additions and 292 deletions

View File

@@ -82,6 +82,7 @@ bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
} }
return true; return true;
} }
bool ArgmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs, bool ArgmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) { const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMax>(base_operator); auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMax>(base_operator);

View File

@@ -66,11 +66,32 @@ void CopyTask(size_t cur, std::vector<size_t> *pos, T *input, const I *index, co
} }
} // namespace } // namespace
void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { bool GatherDCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
MS_EXCEPTION_IF_NULL(kernel_node); const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); kernel_name_ = base_operator->name();
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); if (inputs.size() != 3) {
index_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2); MS_LOG(ERROR) << "GatherD input size must be equal to 3!";
return false;
}
if (auto ret = MatchKernelFunc(base_operator, inputs, outputs); !ret) {
return ret;
}
return true;
}
int GatherDCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
const size_t kIndexIdx = 2;
auto input_shape = inputs[0]->GetShapeVector();
auto index_shape = inputs[kIndexIdx]->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
(void)std::transform(index_shape.begin(), index_shape.end(), std::back_inserter(index_shape_), LongToSize);
if (input_shape_.size() != index_shape_.size()) { if (input_shape_.size() != index_shape_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', shape size of 'x' must be equal to 'index', but got shape size of 'x': " << "', shape size of 'x' must be equal to 'index', but got shape size of 'x': "
@@ -78,20 +99,11 @@ void GatherDCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
} }
output_shape_ = index_shape_; output_shape_ = index_shape_;
std::vector<KernelAttr> support_list; return KRET_OK;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, GatherDFunc> &pair) { return pair.first; });
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "GatherD does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
} }
template <typename T, typename I> template <typename T, typename I>
bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherDInputsNum, kernel_name_); CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGatherDInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherDOutputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGatherDOutputsNum, kernel_name_);
@@ -157,67 +169,71 @@ bool GatherDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
return true; return true;
} }
std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::GatherDFunc>> GatherDCpuKernelMod::func_list_ = { const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> &GatherDCpuKernelMod::GetFuncList() const {
{KernelAttr() static const std::vector<std::pair<KernelAttr, GatherDCpuKernelMod::KernelRunFunc>> func_list = {
.AddInputAttr(kNumberTypeFloat32) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32), .AddInputAttr(kNumberTypeInt32)
&GatherDCpuKernelMod::LaunchKernel<float, int32_t>}, .AddOutputAttr(kNumberTypeFloat32),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<float, int32_t>},
.AddInputAttr(kNumberTypeFloat32) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32), .AddInputAttr(kNumberTypeInt64)
&GatherDCpuKernelMod::LaunchKernel<float, int64_t>}, .AddOutputAttr(kNumberTypeFloat32),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<float, int64_t>},
.AddInputAttr(kNumberTypeFloat16) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16), .AddInputAttr(kNumberTypeInt32)
&GatherDCpuKernelMod::LaunchKernel<float16, int32_t>}, .AddOutputAttr(kNumberTypeFloat16),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<float16, int32_t>},
.AddInputAttr(kNumberTypeFloat16) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16), .AddInputAttr(kNumberTypeInt64)
&GatherDCpuKernelMod::LaunchKernel<float16, int64_t>}, .AddOutputAttr(kNumberTypeFloat16),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<float16, int64_t>},
.AddInputAttr(kNumberTypeInt32) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32), .AddInputAttr(kNumberTypeInt32)
&GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>}, .AddOutputAttr(kNumberTypeInt32),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<int32_t, int32_t>},
.AddInputAttr(kNumberTypeInt32) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32), .AddInputAttr(kNumberTypeInt64)
&GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>}, .AddOutputAttr(kNumberTypeInt32),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<int32_t, int64_t>},
.AddInputAttr(kNumberTypeInt64) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64), .AddInputAttr(kNumberTypeInt32)
&GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>}, .AddOutputAttr(kNumberTypeInt64),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<int64_t, int32_t>},
.AddInputAttr(kNumberTypeInt64) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64), .AddInputAttr(kNumberTypeInt64)
&GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>}, .AddOutputAttr(kNumberTypeInt64),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<int64_t, int64_t>},
.AddInputAttr(kNumberTypeBool) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool), .AddInputAttr(kNumberTypeInt32)
&GatherDCpuKernelMod::LaunchKernel<bool, int32_t>}, .AddOutputAttr(kNumberTypeBool),
{KernelAttr() &GatherDCpuKernelMod::LaunchKernel<bool, int32_t>},
.AddInputAttr(kNumberTypeBool) {KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool), .AddInputAttr(kNumberTypeInt64)
&GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}}; .AddOutputAttr(kNumberTypeBool),
&GatherDCpuKernelMod::LaunchKernel<bool, int64_t>}};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GatherD, GatherDCpuKernelMod); MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GatherD, GatherDCpuKernelMod);
} // namespace kernel } // namespace kernel

View File

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHER_D_CPU_KERNEL_H_
#include <vector> #include <vector>
#include <map>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -23,25 +24,27 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class GatherDCpuKernelMod : public DeprecatedNativeCpuKernelMod { class GatherDCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GatherDCpuKernelMod> {
public: public:
GatherDCpuKernelMod() = default; GatherDCpuKernelMod() = default;
~GatherDCpuKernelMod() override = default; ~GatherDCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override; bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override { const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs); return kernel_func_(this, inputs, workspace, outputs);
} }
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
private: private:
template <typename T, typename I> template <typename T, typename I>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
using GatherDFunc = std::function<bool(GatherDCpuKernelMod *, const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &outputs);
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, GatherDFunc>> func_list_;
GatherDFunc kernel_func_;
std::vector<size_t> input_shape_; std::vector<size_t> input_shape_;
std::vector<size_t> index_shape_; std::vector<size_t> index_shape_;
std::vector<size_t> output_shape_; std::vector<size_t> output_shape_;

View File

@@ -15,92 +15,341 @@
*/ */
#include "plugin/device/gpu/kernel/arrays/gather_gpu_kernel.h" #include "plugin/device/gpu/kernel/arrays/gather_gpu_kernel.h"
#include <string>
#include <algorithm>
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
MS_REG_GPU_KERNEL_TWO( std::vector<std::pair<KernelAttr, GatherFwdGpuKernelMod::GatherFwdFunc>> GatherFwdGpuKernelMod::func_list_ = {
GatherD, // For static shape case:
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherFwdGpuKernelMod, double, int) &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
GatherD, &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherFwdGpuKernelMod, double, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
GatherD, &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherFwdGpuKernelMod, float, int) &GatherFwdGpuKernelMod::LaunchKernel<half, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
GatherD, &GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherFwdGpuKernelMod, float, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<int, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
GatherD, &GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
GatherFwdGpuKernelMod, half, int) &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
GatherD, &GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherFwdGpuKernelMod, half, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), &GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
GatherFwdGpuKernelMod, int, int) {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
GatherFwdGpuKernelMod, int, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), &GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
GatherFwdGpuKernelMod, int8_t, int) {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherFwdGpuKernelMod, int8_t, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), &GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
GatherFwdGpuKernelMod, int16_t, int) {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
GatherFwdGpuKernelMod, int16_t, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int>},
GatherFwdGpuKernelMod, int64_t, int) {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
GatherFwdGpuKernelMod, int64_t, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
MS_REG_GPU_KERNEL_TWO( {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), &GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
GatherFwdGpuKernelMod, uint, int) {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
GatherFwdGpuKernelMod, uint, int64_t) &GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
MS_REG_GPU_KERNEL_TWO( // For dynamic shape case:
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), {KernelAttr()
GatherFwdGpuKernelMod, uchar, int) .AddInputAttr(kNumberTypeFloat64)
MS_REG_GPU_KERNEL_TWO( .AddInputAttr(kNumberTypeInt64)
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), .AddInputAttr(kNumberTypeInt32)
GatherFwdGpuKernelMod, uchar, int64_t) .AddOutputAttr(kNumberTypeFloat64),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<double, int>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), {KernelAttr()
GatherFwdGpuKernelMod, bool, int) .AddInputAttr(kNumberTypeFloat64)
MS_REG_GPU_KERNEL_TWO( .AddInputAttr(kNumberTypeInt64)
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), .AddInputAttr(kNumberTypeInt64)
GatherFwdGpuKernelMod, bool, int64_t) .AddOutputAttr(kNumberTypeFloat64),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<double, int64_t>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), {KernelAttr()
GatherFwdGpuKernelMod, uint32_t, int) .AddInputAttr(kNumberTypeFloat32)
MS_REG_GPU_KERNEL_TWO( .AddInputAttr(kNumberTypeInt64)
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), .AddInputAttr(kNumberTypeInt32)
GatherFwdGpuKernelMod, uint32_t, int64_t) .AddOutputAttr(kNumberTypeFloat32),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<float, int>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), {KernelAttr()
GatherFwdGpuKernelMod, uint64_t, int) .AddInputAttr(kNumberTypeFloat32)
MS_REG_GPU_KERNEL_TWO( .AddInputAttr(kNumberTypeInt64)
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), .AddInputAttr(kNumberTypeInt64)
GatherFwdGpuKernelMod, uint64_t, int64_t) .AddOutputAttr(kNumberTypeFloat32),
MS_REG_GPU_KERNEL_TWO( &GatherFwdGpuKernelMod::LaunchKernel<float, int64_t>},
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), {KernelAttr()
GatherFwdGpuKernelMod, uint16_t, int) .AddInputAttr(kNumberTypeFloat16)
MS_REG_GPU_KERNEL_TWO( .AddInputAttr(kNumberTypeInt64)
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), .AddInputAttr(kNumberTypeInt32)
GatherFwdGpuKernelMod, uint16_t, int64_t) .AddOutputAttr(kNumberTypeFloat16),
&GatherFwdGpuKernelMod::LaunchKernel<half, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&GatherFwdGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&GatherFwdGpuKernelMod::LaunchKernel<int, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
&GatherFwdGpuKernelMod::LaunchKernel<int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
&GatherFwdGpuKernelMod::LaunchKernel<int8_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
&GatherFwdGpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
&GatherFwdGpuKernelMod::LaunchKernel<int16_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
&GatherFwdGpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&GatherFwdGpuKernelMod::LaunchKernel<int64_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&GatherFwdGpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
&GatherFwdGpuKernelMod::LaunchKernel<uint, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
&GatherFwdGpuKernelMod::LaunchKernel<uint, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
&GatherFwdGpuKernelMod::LaunchKernel<uchar, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
&GatherFwdGpuKernelMod::LaunchKernel<uchar, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool),
&GatherFwdGpuKernelMod::LaunchKernel<bool, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
&GatherFwdGpuKernelMod::LaunchKernel<bool, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
&GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
&GatherFwdGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64),
&GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt64),
&GatherFwdGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
&GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
&GatherFwdGpuKernelMod::LaunchKernel<uint16_t, int64_t>}};
// Validates the gather axis and collapses the input/output shapes into the
// four launch dimensions consumed by the Gather CUDA kernel:
//   dims_[0] = product of output dims before the axis
//   dims_[1] = input extent at the axis
//   dims_[2] = output extent at the axis
//   dims_[3] = product of output dims after the axis
// Returns false (with an error log) when 'dim_value' is outside [-rank, rank).
bool GatherFwdGpuKernelMod::SetDimParam(int64_t dim_value) {
  const int64_t x_rank = SizeToLong(input_shapes_.size());
  const bool dim_out_of_range = (dim_value < -x_rank) || (dim_value >= x_rank);
  if (dim_out_of_range) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dim' must be in the range [-" << x_rank << "," << x_rank
                  << "), but got " << dim_value;
    return false;
  }
  // Canonicalize a negative axis into [0, x_rank).
  const size_t axis = LongToSize(dim_value < 0 ? dim_value + x_rank : dim_value);

  size_t before_axis = 1;
  for (size_t i = 0; i < axis; ++i) {
    before_axis *= output_shapes_[i];
  }
  size_t after_axis = 1;
  for (size_t i = axis + 1; i < output_shapes_.size(); ++i) {
    after_axis *= output_shapes_[i];
  }

  const size_t kAxisInputIdx = 1;
  const size_t kAxisOutputIdx = 2;
  const size_t kAfterAxisIdx = 3;
  dims_[0] = before_axis;
  dims_[kAxisInputIdx] = input_shapes_[axis];
  dims_[kAxisOutputIdx] = output_shapes_[axis];
  dims_[kAfterAxisIdx] = after_axis;
  return true;
}
// Initializes the kernel: decides between the static-shape (2-input) and
// dynamic-shape (3-input) GatherD form, then selects the typed launch
// function matching the runtime kernel attribute. Returns false on an
// unsupported input count or data-type combination.
bool GatherFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  const size_t kStaticInputNum = 2;
  const size_t kDynInputNum = 3;
  const size_t input_num = inputs.size();
  switch (input_num) {
    case kStaticInputNum:
      is_dynamic_case_ = false;
      break;
    case kDynInputNum: {
      // Dynamic case carries 'dim' as input 1, so the index tensor moves to slot 2.
      is_dynamic_case_ = true;
      const size_t kDynIndexIdx = 2;
      index_idx_ = kDynIndexIdx;
      break;
    }
    default:
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
      return false;
  }
  const auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  const auto [matched, func_idx] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!matched) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[func_idx].second;
  return true;
}
// Re-computes the cached shape/dimension state for the current input shapes.
// Steps: delegate to KernelMod::Resize, snapshot input/index/output shapes,
// short-circuit on empty (null) shapes, validate rank agreement, resolve the
// 'dim' value (attribute for the static form, host tensor for the dynamic
// form), and fold everything into dims_ via SetDimParam.
// Returns KRET_OK on success, KRET_RESIZE_FAILED on any validation failure.
int GatherFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
  if (ret != KRET_OK) {
    return ret;
  }
  auto input_shapes = inputs[0]->GetShapeVector();
  auto index_shapes = inputs[index_idx_]->GetShapeVector();
  auto output_shapes = outputs[0]->GetShapeVector();
  // Resize may be called repeatedly for dynamic shapes; rebuild from scratch.
  input_shapes_.clear();
  index_shapes_.clear();
  output_shapes_.clear();
  (void)std::transform(input_shapes.cbegin(), input_shapes.cend(), std::back_inserter(input_shapes_), LongToSize);
  (void)std::transform(index_shapes.cbegin(), index_shapes.cend(), std::back_inserter(index_shapes_), LongToSize);
  (void)std::transform(output_shapes.cbegin(), output_shapes.cend(), std::back_inserter(output_shapes_), LongToSize);
  is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
                   CHECK_SHAPE_NULL(index_shapes_, kernel_name_, "input_indices") ||
                   CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");
  if (is_null_input_) {
    // Nothing to launch; Launch() returns early for null inputs.
    return KRET_OK;
  }
  if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input and output must be the equal to the "
                  << "dimension of index: " << index_shapes_.size()
                  << ", but got the dimension of input: " << input_shapes_.size()
                  << ", the dimension of output: " << output_shapes_.size();
    return KRET_RESIZE_FAILED;
  }
  int64_t dim_value = 0;
  if (!is_dynamic_case_) {
    // Static form: 'dim' is a primitive attribute.
    const std::string kAttrDim = "dim";
    auto dim_attr = base_operator->GetPrim()->GetAttr(kAttrDim);
    if (dim_attr == nullptr) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the attribute 'dim' is missing.";
      return KRET_RESIZE_FAILED;
    }
    dim_value = GetValue<int64_t>(dim_attr);
  } else {
    // Dynamic form: 'dim' arrives as input 1 and must be readable on host.
    // Bug fix: the return value was previously ignored, silently using dim=0
    // when the host value was unavailable.
    if (!GetDynamicAttrIntValue(inputs, 1, inputsOnHost, kernel_name_, &dim_value)) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', failed to get the value of 'dim' from host data.";
      return KRET_RESIZE_FAILED;
    }
  }
  if (!SetDimParam(dim_value)) {
    return KRET_RESIZE_FAILED;
  }
  return KRET_OK;
}
std::vector<KernelAttr> GatherFwdGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, GatherFwdFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GatherD, GatherFwdGpuKernelMod);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@@ -14,124 +14,83 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_
#include <vector> #include <vector>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cuh" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T, typename S> class GatherFwdGpuKernelMod : public NativeGpuKernelMod {
class GatherFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
public: public:
GatherFwdGpuKernelMod() : axis_(0), is_null_input_(false) {} GatherFwdGpuKernelMod() {}
~GatherFwdGpuKernelMod() = default; ~GatherFwdGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override { const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (!kernel_func_) {
MS_LOG(ERROR) << "GatherFwdGpu's kernel function is not initialized.";
return false;
}
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) { if (is_null_input_) {
return true; return true;
} }
VARIABLE_NOT_USED(workspace); VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *index_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3], auto input_addr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID); auto index_addr = reinterpret_cast<S *>(inputs.at(index_idx_)->addr);
return true; auto output_addr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
} auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); const size_t k2Idx = 2;
kernel_node_ = kernel_node; const size_t k3Idx = 3;
InitResource(); Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[k2Idx], dims_[k3Idx], cuda_stream,
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); GET_CTX_DEVICE_ID);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
}
input_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
index_shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shapes_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
CHECK_SHAPE_NULL(index_shapes_, kernel_name, "input_indices") ||
CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input and output must be the equal to the "
<< "dimension of index: " << index_shapes_.size()
<< ", but got the dimension of input: " << input_shapes_.size()
<< ", the dimension of output: " << output_shapes_.size();
}
int dims = SizeToInt(input_shapes_.size());
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "dim"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
}
if (axis_ < 0) {
axis_ += dims;
}
Reshape();
InitSizeLists();
return true; return true;
} }
protected: bool SetDimParam(int64_t dim_value);
void InitResource() override {}
void InitSizeLists() override {
size_t size = GetSize(input_shapes_, true);
input_size_list_.push_back(size);
size = GetSize(index_shapes_, false); using GatherFwdFunc =
input_size_list_.push_back(size); std::function<bool(GatherFwdGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
size = GetSize(output_shapes_, true); GatherFwdFunc kernel_func_;
output_size_list_.push_back(size); static std::vector<std::pair<KernelAttr, GatherFwdFunc>> func_list_;
}
private:
void Reshape() {
size_t dim_before_axis = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_at_axis_input = input_shapes_[IntToSize(axis_)];
size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
size_t dim_after_axis = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_axis *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_at_axis_input;
dims_[2] = dim_at_axis_output;
dims_[3] = dim_after_axis;
return;
}
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
size_t result = flag ? sizeof(T) : sizeof(S);
for (size_t i = 0; i < shape.size(); i++) {
result *= shape[i];
}
return result;
}
std::vector<size_t> input_shapes_; std::vector<size_t> input_shapes_;
std::vector<size_t> index_shapes_; std::vector<size_t> index_shapes_;
std::vector<size_t> output_shapes_; std::vector<size_t> output_shapes_;
size_t dims_[4] = {}; size_t dims_[4] = {};
int axis_; bool is_null_input_{false};
bool is_null_input_; bool is_dynamic_case_{false};
size_t index_idx_{1};
}; };
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHER_GPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_GATHER_GPU_KERNEL_H_

View File

@ -385,10 +385,15 @@ std::optional<std::vector<int64_t>> GetDynamicAttrIntValue(
<< "th input in the depend_tensor_map"; << "th input in the depend_tensor_map";
} }
auto input_tensor = depend_iter->second; auto input_tensor = depend_iter->second;
const auto &input_shape = inputs[input_index]->GetShapeVector(); auto input_shape = inputs[input_index]->GetShapeVector();
// The shape keep in depend_tensor_map was processed if it was empty.
if (input_shape.empty()) {
input_shape.push_back(1);
}
if (input_shape != input_tensor->shape()) { if (input_shape != input_tensor->shape()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the " << input_index MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the " << input_index
<< "th input is different between the InferShape and the TensorShape"; << "th input is different between the InferShape and the TensorShape: " << input_shape << " vs "
<< input_tensor->shape();
} }
const auto &data_format = inputs[input_index]->GetFormat(); const auto &data_format = inputs[input_index]->GetFormat();
if (data_format != mindspore::Format::DEFAULT_FORMAT && data_format != mindspore::Format::NCHW) { if (data_format != mindspore::Format::DEFAULT_FORMAT && data_format != mindspore::Format::NCHW) {

View File

@ -69,6 +69,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name(); static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
static const auto &kGather = prim::kPrimGather->name(); static const auto &kGather = prim::kPrimGather->name();
static const auto &kGatherV2 = prim::kPrimGatherV2->name(); static const auto &kGatherV2 = prim::kPrimGatherV2->name();
static const auto &kGatherD = prim::kPrimGatherD->name();
static const auto &kSparseGatherV2 = prim::kPrimSparseGatherV2->name(); static const auto &kSparseGatherV2 = prim::kPrimSparseGatherV2->name();
static const auto &kRange = prim::kPrimRange->name(); static const auto &kRange = prim::kPrimRange->name();
static const auto &kRangeV2 = prim::kPrimRangeV2->name(); static const auto &kRangeV2 = prim::kPrimRangeV2->name();
@ -101,6 +102,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
{kMatrixSetDiagV3, ShapeSet{2}}, {kMatrixSetDiagV3, ShapeSet{2}},
{kGather, ShapeSet{2}}, {kGather, ShapeSet{2}},
{kGatherV2, ShapeSet{2}}, {kGatherV2, ShapeSet{2}},
{kGatherD, ShapeSet{1}},
{kSparseGatherV2, ShapeSet{2}}, {kSparseGatherV2, ShapeSet{2}},
{kRange, ShapeSet{0, 1, 2}}, {kRange, ShapeSet{0, 1, 2}},
{kRangeV2, ShapeSet{0, 1, 2}}, {kRangeV2, ShapeSet{0, 1, 2}},

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,35 +26,97 @@ namespace mindspore {
namespace ops { namespace ops {
// gather_d // gather_d
namespace { namespace {
int64_t GetGatherDimValue(const AbstractBasePtr dim_ptr) {
MS_EXCEPTION_IF_NULL(dim_ptr);
auto dim_value_ptr = dim_ptr->BuildValue();
MS_EXCEPTION_IF_NULL(dim_value_ptr);
auto dim_type_ptr = dim_ptr->BuildType();
MS_EXCEPTION_IF_NULL(dim_type_ptr);
int64_t dim_v = 0;
if (dim_value_ptr->isa<tensor::Tensor>()) {
auto dim_tensor = dim_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(dim_tensor);
size_t data_size = dim_tensor->DataSize();
MS_EXCEPTION_IF_CHECK_FAIL(data_size == 1, "dim value is not equal to one!");
auto dim_type_id = dim_type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(dim_type_id);
auto element = dim_type_id->element();
MS_EXCEPTION_IF_NULL(element);
if (element->type_id() == kNumberTypeInt32) {
auto dim_data32 = reinterpret_cast<int *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data32);
dim_v = static_cast<int64_t>(*dim_data32);
} else {
auto dim_data64 = reinterpret_cast<int64_t *>(dim_tensor->data_c());
MS_EXCEPTION_IF_NULL(dim_data64);
dim_v = static_cast<int64_t>(*dim_data64);
}
} else {
if (dim_value_ptr->isa<Int32Imm>() || dim_value_ptr->isa<Int64Imm>()) {
dim_v = GetValue<int64_t>(dim_value_ptr);
} else {
MS_LOG(EXCEPTION) << "For GatherD, 'dim' must be one of these types: [int32/int64].";
}
}
return dim_v;
}
bool IsShapeInValid(const ShapeVector &shape) {
return std::any_of(shape.cbegin(), shape.cend(), [](int64_t s) { return s < 0; });
}
void CheckGatherShapeEqual(const std::string &prim_name, const ShapeVector &x_shape, int64_t dim_v,
const ShapeVector &index_shape) {
if (IsShapeInValid(x_shape) || IsShapeInValid(index_shape)) {
return;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (SizeToLong(i) == dim_v) continue;
MS_LOG(INFO) << "For '" << prim_name << "', it's now checking " << i << "th x shape.";
CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, index_shape[i], prim_name);
}
}
abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
// check
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; const size_t gather_d_input_num = 3;
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == gather_d_input_num,
"GatherD's input size should be 3 but got " + std::to_string(input_args.size()));
MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex0]->BuildShape()->isa<abstract::Shape>(), "x's shape wrong.");
auto shape_element = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
auto x_shape = shape_element->shape();
auto x_min_shape = shape_element->min_shape();
auto x_max_shape = shape_element->max_shape();
MS_EXCEPTION_IF_CHECK_FAIL(input_args[kInputIndex2]->BuildShape()->isa<abstract::Shape>(), "index's shape wrong.");
auto index_shape_element = input_args[kInputIndex2]->BuildShape()->cast<abstract::ShapePtr>();
auto index_shape = index_shape_element->shape();
auto index_min_shape = index_shape_element->min_shape();
auto index_max_shape = index_shape_element->max_shape();
int64_t x_rank = SizeToLong(x_shape.size()); int64_t x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, SizeToLong(index_shape.size()), prim_name); CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, SizeToLong(index_shape.size()), prim_name);
auto value_ptr = input_args[1]->BuildValue(); auto dim_v = GetGatherDimValue(input_args[kInputIndex1]);
MS_EXCEPTION_IF_NULL(value_ptr);
auto dim_v = GetValue<int64_t>(value_ptr);
CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, -x_rank, prim_name); CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, -x_rank, prim_name);
CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, x_rank, prim_name); CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, x_rank, prim_name);
if (dim_v < 0) { if (dim_v < 0) {
dim_v = dim_v + x_rank; dim_v = dim_v + x_rank;
} }
for (size_t i = 0; i < x_shape.size(); ++i) {
if (SizeToLong(i) == dim_v) continue; // For Ascend, only support x.shape[d] == index.shape[d] when d != dim. So limit it.
MS_LOG(INFO) << "For '" << prim_name << "', it's now checking " << i << "th x shape."; CheckGatherShapeEqual(prim_name, x_shape, dim_v, index_shape);
CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, index_shape[i], prim_name); CheckGatherShapeEqual(prim_name, x_min_shape, dim_v, index_min_shape);
} CheckGatherShapeEqual(prim_name, x_max_shape, dim_v, index_max_shape);
return std::make_shared<abstract::Shape>(index_shape); return std::make_shared<abstract::Shape>(index_shape, index_min_shape, index_max_shape);
} }
TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name(); auto prim_name = prim->name();
// check
std::set<TypePtr> valid_x_type = {kTensorType}; std::set<TypePtr> valid_x_type = {kTensorType};
auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_x_type, prim_name); auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_x_type, prim_name);
return x_type; return x_type;
@ -70,10 +132,12 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
} }
auto prim_name = primitive->name(); auto prim_name = primitive->name();
// check // check
std::set<TypePtr> valid_types = {kInt32, kInt64}; std::set<TypePtr> index_valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), valid_types, std::set<TypePtr> dim_valid_types = {kInt32, kInt64, std::make_shared<TensorType>(kInt32),
std::make_shared<TensorType>(kInt64)};
(void)CheckAndConvertUtils::CheckTensorTypeValid("index", input_args[kInputIndex2]->BuildType(), index_valid_types,
prim_name); prim_name);
(void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), valid_types, prim_name); (void)CheckAndConvertUtils::CheckSubClass("dim", input_args[kInputIndex1]->BuildType(), dim_valid_types, prim_name);
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args)); return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
} }
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true); REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);

View File

@ -21,7 +21,7 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr from mindspore.ops import constexpr
from ..primitive import Primitive from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \ from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \
_handle_broadcasting, get_unsupported_dynamic_vmap_rule _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis
from ..operations.array_ops import Fills from ..operations.array_ops import Fills
from ..operations.array_ops import UniqueConsecutive from ..operations.array_ops import UniqueConsecutive
@ -541,6 +541,48 @@ def get_ger_vmap_rule(prim, axis_size):
return vmap_rule return vmap_rule
@vmap_rules_getters.register(P.GatherD)
def get_gatherd_vmap_rule(prim, axis_size):
"""VmapRule for GatherD operations."""
if isinstance(prim, str):
prim = Primitive(prim)
def vmap_rule(x_bdim, dim_bdim, index_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
dim_value, axis_dim = dim_bdim
index, index_dim = index_bdim
# `dim` will be a Tensor in dynamic shape case, do not support its vamp.
if axis_dim is not None:
_raise_value_error("The source axis of `dim` in `GatherD` must be None, "
"but got {}.".format(axis_dim))
if not isinstance(dim_value, int):
_raise_value_error("The `dim` in `GatherD` must be a const, but got {}.".format(dim_value))
out_dim = index_dim
# Broadcast if needed.
if x_dim is None:
x = _broadcast_by_axis(x, index_dim, axis_size)
elif index_dim is None:
index = _broadcast_by_axis(index, x_dim, axis_size)
out_dim = x_dim
elif x_dim != index_dim:
mnp.moveaxis(x, x_dim, index_dim)
# Adapt `dim` to vmap case.
dim_value = dim_value + 1 if dim_value >= out_dim else dim_value
out = prim(x, dim_value, index)
return (out, out_dim)
return vmap_rule
@vmap_rules_getters.register(P.SpaceToBatchND) @vmap_rules_getters.register(P.SpaceToBatchND)
def get_space_to_batch_nd_vmap_rule(prim, axis_size): def get_space_to_batch_nd_vmap_rule(prim, axis_size):
"""VmapRule for `SpaceToBatchND`.""" """VmapRule for `SpaceToBatchND`."""

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,14 +13,20 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import numpy as np import numpy as np
import pytest
import mindspore import mindspore
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.functional import vmap
context.set_context(device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, dim=0): def __init__(self, dim=0):
super(Net, self).__init__() super(Net, self).__init__()
@ -40,26 +46,89 @@ class NetGrad(nn.Cell):
return self.op(index, x) return self.op(index, x)
def test_net(): def get_data(ms_type):
x = Tensor(np.array([[772, 231, 508, 545, 615, 249], x = Tensor(np.array([[772, 231, 508, 545, 615, 249],
[923, 210, 480, 696, 482, 761], [923, 210, 480, 696, 482, 761],
[465, 904, 521, 824, 607, 669], [465, 904, 521, 824, 607, 669],
[156, 539, 56, 159, 916, 566], [156, 539, 56, 159, 916, 566],
[122, 676, 714, 261, 19, 936]]), mindspore.int32) [122, 676, 714, 261, 19, 936]]), ms_type)
index = Tensor(np.array([[0, 0, 0, 1, 1],
[0, 0, 0, 1, 4],
[0, 0, 0, 1, -1],
[1, 1, 1, 0, 0]]), mindspore.int32)
dim = 0 dim = 0
index = Tensor(np.array([[0, 1, 0, 1, 0, -4],
[0, 2, 0, 2, 0, -3],
[0, 0, 0, 3, 3, -2],
[4, 4, 4, 0, 0, -1],
[4, 3, 2, 1, -1, -2]]), mindspore.int32)
expect = np.array([[772, 210, 508, 696, 615, 761],
[772, 904, 508, 824, 615, 669],
[772, 231, 508, 159, 916, 566],
[122, 676, 714, 545, 615, 936],
[122, 539, 521, 696, 19, 566]])
res = (x, dim, index, expect)
return res
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
def test_net(ms_type):
"""
Feature: test GatherD static shape.
Description: input x and index is static shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE)
x, dim, index, expect = get_data(ms_type)
net = Net(dim) net = Net(dim)
out = net(x, index) out = net(x, index)
print(out.asnumpy())
expect_out = np.array([[772, 231, 508, 696, 482], assert np.array_equal(out.asnumpy(), expect)
[772, 231, 508, 696, 19],
[772, 231, 508, 696, 19],
[923, 210, 480, 545, 615]]) @pytest.mark.level0
assert np.array_equal(out.asnumpy(), expect_out) @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('ms_type', [mindspore.int32, mindspore.uint32, mindspore.float32])
def test_gatherd_dynamic(ms_type):
"""
Feature: test GatherD dynamic shape.
Description: index is dynamic shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE)
x, dim, index, expect = get_data(ms_type)
index_dyn = Tensor(shape=[index.shape[0], None], dtype=mindspore.int32)
net = Net(dim)
net.set_inputs(x, index_dyn)
out = net(x, index)
assert np.array_equal(out.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.int32, np.uint32, np.float32])
def test_gatherd_vmap(dtype):
"""
Feature: test GatherD vmap interface.
Description: input x and index is static shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE)
def cal_gatherd(x, dim, index):
return P.GatherD()(x, dim, index)
gather_dim = 1
x = Tensor(np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]).astype(dtype))
y = Tensor(np.array([[[0, 0], [1, 0]], [[0, 0], [1, 0]], [[0, 0], [1, 0]]]).astype(np.int32))
outputs = vmap(cal_gatherd, in_axes=(0, None, 0), out_axes=0)(x, gather_dim, y)
expect = np.array([[[1, 1], [4, 3]], [[1, 1], [4, 3]], [[1, 1], [4, 3]]]).astype(dtype)
assert np.allclose(outputs.asnumpy(), expect)
def test_net_bool(): def test_net_bool():