!41395 gather_ops_support_dynamic_shape
Merge pull request !41395 from yao_yf/gather_ops_support_dynamic_shape
This commit is contained in:
commit
adf8ec5fb8
|
@ -32,41 +32,57 @@ constexpr size_t kGatherInputsNum = 2;
|
|||
constexpr size_t kGatherOutputsNum = 1;
|
||||
constexpr size_t kGatherInputParamsMaxDim = 7;
|
||||
} // namespace
|
||||
void GatherCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
bool GatherCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
size_t input_num = inputs.size();
|
||||
if (input_num == kGatherInputsNum + 1) {
|
||||
is_dynamic_shape_ = true;
|
||||
MS_LOG(DEBUG) << " GatherCPUKernel running in Dynamic Mode.";
|
||||
} else if (input_num == kGatherInputsNum) {
|
||||
axis_ = GetValue<int64_t>(base_operator->GetAttr("axis"));
|
||||
MS_LOG(DEBUG) << " GatherCPUKernel running in Normal Mode.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherCPUKernel needs 2.";
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
indices_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
|
||||
if (is_dynamic_shape_) {
|
||||
axis_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GatherCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
CheckParam(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
indices_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
int GatherCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
ResetResource();
|
||||
input_shape_ = inputs[kIndexZero]->GetShapeVector();
|
||||
indices_shape_ = inputs[kIndexOne]->GetShapeVector();
|
||||
output_shape_ = outputs[kIndexZero]->GetShapeVector();
|
||||
is_null_input_ = input_shape_.empty() || indices_shape_.empty() || output_shape_.empty();
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
if (input_shape_.size() > kGatherInputParamsMaxDim) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_params' should be "
|
||||
<< kGatherInputParamsMaxDim << "D or lower, but got " << input_shape_.size() << ".";
|
||||
}
|
||||
if (!is_dynamic_shape_) {
|
||||
axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
auto it_x = std::find_if(input_shape_.begin(), input_shape_.end(), [](int64_t sh) { return sh <= 0; });
|
||||
auto it_y = std::find_if(indices_shape_.begin(), indices_shape_.end(), [](int64_t sh) { return sh <= 0; });
|
||||
if (it_x != input_shape_.end() || it_y != indices_shape_.end()) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Gather does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
node_wpt_ = kernel_node;
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -76,19 +92,9 @@ bool GatherCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
|
|||
const auto *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
|
||||
const auto *indices_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
|
||||
auto *output_addr = reinterpret_cast<int8_t *>(outputs[0]->addr);
|
||||
if (!node_wpt_.expired()) {
|
||||
auto node = node_wpt_.lock();
|
||||
if (!node) {
|
||||
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
|
||||
}
|
||||
if (inputs.size() == kGatherInputsNum) {
|
||||
axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(node, AXIS);
|
||||
} else if (inputs.size() == kGatherInputsNum + 1) {
|
||||
axis_ = reinterpret_cast<int64_t *>(inputs[kIndex2]->addr)[0];
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Gather requires " << kGatherInputsNum << " or " << (kGatherInputsNum + 1)
|
||||
<< " inputs, but got " << inputs.size();
|
||||
}
|
||||
|
||||
if (inputs.size() == kGatherInputsNum + 1) {
|
||||
axis_ = reinterpret_cast<int64_t *>(inputs[kIndex2]->addr)[0];
|
||||
}
|
||||
|
||||
int dims = SizeToInt(input_shape_.size());
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
@ -26,14 +27,17 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class GatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class GatherCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
GatherCpuKernelMod() = default;
|
||||
~GatherCpuKernelMod() override = default;
|
||||
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
|
@ -42,6 +46,28 @@ class GatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
void ResetResource() noexcept {
|
||||
input_shape_.clear();
|
||||
indices_shape_.clear();
|
||||
output_shape_.clear();
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
void InitSizeLists() {
|
||||
auto input_size = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies{});
|
||||
auto indices_size = std::accumulate(indices_shape_.begin(), indices_shape_.end(), 1, std::multiplies{});
|
||||
input_size_list_.push_back(LongToSize(input_size) * input_type_size_);
|
||||
input_size_list_.push_back(LongToSize(indices_size) * indices_type_size_);
|
||||
if (is_dynamic_shape_) {
|
||||
input_size_list_.push_back(axis_type_size_);
|
||||
}
|
||||
auto output_size = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies{});
|
||||
output_size_list_.push_back(LongToSize(output_size) * input_type_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
@ -55,7 +81,10 @@ class GatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
ShapeVector output_shape_;
|
||||
int64_t axis_{0};
|
||||
bool is_dynamic_shape_{false};
|
||||
CNodeWeakPtr node_wpt_;
|
||||
size_t input_type_size_ = 0;
|
||||
size_t indices_type_size_ = 0;
|
||||
size_t axis_type_size_ = 0;
|
||||
bool is_null_input_ = false;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,178 +19,222 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const size_t kStaticInputNum = 2;
|
||||
const size_t kDynInputNum = 3;
|
||||
constexpr char GATHER[] = "Gather";
|
||||
constexpr char GATHERV2[] = "GatherV2";
|
||||
constexpr char SPARSEGATHERV2[] = "SPARSEGATHERV2";
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
bool GatherV2FwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
size_t input_num = inputs.size();
|
||||
if (input_num == kDynInputNum) {
|
||||
is_dynamic_shape_ = true;
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
|
||||
} else if (input_num == kStaticInputNum) {
|
||||
axis_ = static_cast<int>(GetValue<int64_t>(base_operator->GetAttr("axis")));
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
|
||||
}
|
||||
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
|
||||
GatherV2FwdGpuKernelMod, Complex<float>, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
|
||||
GatherV2FwdGpuKernelMod, Complex<float>, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
|
||||
GatherV2FwdGpuKernelMod, Complex<double>, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
|
||||
GatherV2FwdGpuKernelMod, Complex<double>, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherV2FwdGpuKernelMod, double, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2FwdGpuKernelMod, float, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2FwdGpuKernelMod, half, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherV2FwdGpuKernelMod, int, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherV2FwdGpuKernelMod, int, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherV2FwdGpuKernelMod, int16_t, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherV2FwdGpuKernelMod, int16_t, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
GatherV2FwdGpuKernelMod, int8_t, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||
GatherV2FwdGpuKernelMod, int8_t, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
GatherV2FwdGpuKernelMod, uint, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
|
||||
GatherV2FwdGpuKernelMod, uint, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherV2FwdGpuKernelMod, uint8_t, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherV2FwdGpuKernelMod, uint8_t, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
GatherV2FwdGpuKernelMod, bool, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
|
||||
// dynamic shape
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherV2FwdGpuKernelMod, double, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2FwdGpuKernelMod, float, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2FwdGpuKernelMod, float, int, int)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2FwdGpuKernelMod, half, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
GatherV2FwdGpuKernelMod, bool, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
GatherV2FwdGpuKernelMod, int, int, int64_t)
|
||||
// dynamic shape ends
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
SparseGatherV2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2FwdGpuKernelMod, float, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
SparseGatherV2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2FwdGpuKernelMod, half, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2FwdGpuKernelMod, float, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2FwdGpuKernelMod, half, int, int64_t)
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
|
||||
input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
indices_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
|
||||
if (is_dynamic_shape_) {
|
||||
axis_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int GatherV2FwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
ResetResource();
|
||||
input_shapes_ = inputs[kIndexZero]->GetShapeVector();
|
||||
indices_shapes_ = inputs[kIndexOne]->GetShapeVector();
|
||||
output_shapes_ = outputs[kIndexZero]->GetShapeVector();
|
||||
auto it_x = std::find_if(input_shapes_.begin(), input_shapes_.end(), [](int64_t sh) { return sh <= 0; });
|
||||
auto it_y = std::find_if(indices_shapes_.begin(), indices_shapes_.end(), [](int64_t sh) { return sh <= 0; });
|
||||
if (it_x != input_shapes_.end() || it_y != indices_shapes_.end()) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
|
||||
CHECK_SHAPE_NULL(indices_shapes_, kernel_name_, "indices") ||
|
||||
CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
if (!is_dynamic_shape_) {
|
||||
int dims = SizeToInt(input_shapes_.size());
|
||||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
|
||||
<< "), but got " << axis_;
|
||||
}
|
||||
Reshape();
|
||||
}
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GatherV2FwdGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const auto &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename G>
|
||||
bool GatherV2FwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
S *indices_addr = GetDeviceAddress<S>(inputs, kIndex1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
|
||||
if (is_dynamic_shape_) {
|
||||
G *axis_device_address = GetDeviceAddress<G>(inputs, kIndex2); // only get this if in dynamic mode
|
||||
G axis = 0;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&axis, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpy seq_lengths from device to host failed.");
|
||||
axis_ = static_cast<int>(axis);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaDeviceSynchronize(), "cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
|
||||
Reshape();
|
||||
}
|
||||
auto input_dim1 = input_shapes_[IntToSize(axis_)];
|
||||
|
||||
MS_EXCEPTION_IF_NULL(input_addr);
|
||||
MS_EXCEPTION_IF_NULL(indices_addr);
|
||||
GatherV2(input_addr, indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2],
|
||||
LongToSize(input_dim1), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, GatherV2FwdGpuKernelMod::GatherV2Func>> GatherV2FwdGpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<float>, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<float>, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<double>, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<double>, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<int, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<int, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<int16_t, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<int16_t, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<int8_t, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<int8_t, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<uint, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<uint, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<uint8_t, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<uint8_t, int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int64_t, int64_t>},
|
||||
// dynamic shape
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int, int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&GatherV2FwdGpuKernelMod::LaunchKernel<int, int, int64_t>},
|
||||
};
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Gather, GatherV2FwdGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
|
@ -27,81 +29,27 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S, typename G>
|
||||
class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
GatherV2FwdGpuKernelMod() { ResetResource(); }
|
||||
~GatherV2FwdGpuKernelMod() = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
if (is_dynamic_shape_) {
|
||||
G *axis_device_address = GetDeviceAddress<G>(inputs, 2); // only get this if in dynamic mode
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(&axis_, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync axis_ failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
|
||||
"cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
|
||||
Reshape();
|
||||
}
|
||||
auto input_dim1 = input_shapes_[axis_];
|
||||
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(input_addr);
|
||||
MS_EXCEPTION_IF_NULL(indices_addr);
|
||||
GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], LongToSize(input_dim1),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
InitResource();
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num == 3) {
|
||||
is_dynamic_shape_ = true;
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
|
||||
} else if (input_num == 2) {
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2 or 3, but got " << input_num;
|
||||
}
|
||||
input_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
|
||||
indices_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1);
|
||||
output_shapes_ = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
|
||||
CHECK_SHAPE_NULL(indices_shapes_, kernel_name, "indices") ||
|
||||
CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (!is_dynamic_shape_) {
|
||||
int dims = SizeToInt(input_shapes_.size());
|
||||
axis_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "axis"));
|
||||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
|
||||
<< "), but got " << axis_;
|
||||
}
|
||||
Reshape();
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
void ResetResource() noexcept override {
|
||||
is_dynamic_shape_ = false;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
void ResetResource() noexcept {
|
||||
input_shapes_.clear();
|
||||
indices_shapes_.clear();
|
||||
output_shapes_.clear();
|
||||
std::fill(dims_, dims_ + 3, 0);
|
||||
axis_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
|
@ -109,19 +57,23 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
size_t size = common::AnfAlgo::TensorSizeInByte<T>(input_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
size = common::AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
void InitSizeLists() {
|
||||
auto input_size = std::accumulate(input_shapes_.begin(), input_shapes_.end(), 1, std::multiplies{});
|
||||
auto indices_size = std::accumulate(indices_shapes_.begin(), indices_shapes_.end(), 1, std::multiplies{});
|
||||
input_size_list_.push_back(LongToSize(input_size) * input_type_size_);
|
||||
input_size_list_.push_back(LongToSize(indices_size) * indices_type_size_);
|
||||
if (is_dynamic_shape_) {
|
||||
input_size_list_.push_back(sizeof(G));
|
||||
input_size_list_.push_back(axis_type_size_);
|
||||
}
|
||||
size = common::AnfAlgo::TensorSizeInByte<T>(output_shapes_);
|
||||
output_size_list_.push_back(size);
|
||||
auto output_size = std::accumulate(output_shapes_.begin(), output_shapes_.end(), 1, std::multiplies{});
|
||||
output_size_list_.push_back(LongToSize(output_size) * input_type_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T, typename S, typename G>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr);
|
||||
|
||||
void Reshape() {
|
||||
if (axis_ < 0) {
|
||||
axis_ = axis_ + SizeToInt(input_shapes_.size());
|
||||
|
@ -138,19 +90,29 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
|
||||
dim_after_indices *= output_shapes_[i];
|
||||
}
|
||||
dims_[0] = dim_before_axis;
|
||||
dims_[1] = dim_of_indices;
|
||||
dims_[2] = dim_after_indices;
|
||||
dims_[kIndex0] = dim_before_axis;
|
||||
dims_[kIndex1] = dim_of_indices;
|
||||
dims_[kIndex2] = dim_after_indices;
|
||||
return;
|
||||
}
|
||||
|
||||
cudaStream_t cuda_stream_;
|
||||
std::vector<int64_t> input_shapes_;
|
||||
std::vector<int64_t> indices_shapes_;
|
||||
std::vector<int64_t> output_shapes_;
|
||||
int64_t dims_[3] = {};
|
||||
G axis_;
|
||||
bool is_dynamic_shape_;
|
||||
bool is_null_input_;
|
||||
int64_t dims_[kIndex3] = {};
|
||||
int axis_ = 0;
|
||||
bool is_dynamic_shape_ = false;
|
||||
bool is_null_input_ = false;
|
||||
size_t input_type_size_ = 0;
|
||||
size_t indices_type_size_ = 0;
|
||||
size_t axis_type_size_ = 0;
|
||||
|
||||
private:
|
||||
using GatherV2Func = std::function<bool(GatherV2FwdGpuKernelMod *, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
|
||||
static std::vector<std::pair<KernelAttr, GatherV2Func>> func_list_;
|
||||
// static std::vector<std::pair<KernelAttr, GatherV2Func>> sparse_gather_func_list_;
|
||||
GatherV2Func kernel_func_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/arrays/sparse_gatherv2_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// GPU kernel registrations for SparseGatherV2.
// Template arguments map to <T = data type, S = index type, G = axis type>.
// The first two signatures are the 2-input (static axis) form; the last two
// add an int64 axis tensor input for the dynamic-shape form.
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
                        KernelAttr()
                          .AddInputAttr(kNumberTypeFloat32)
                          .AddInputAttr(kNumberTypeInt32)
                          .AddOutputAttr(kNumberTypeFloat32),
                        SparseGatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
                        KernelAttr()
                          .AddInputAttr(kNumberTypeFloat16)
                          .AddInputAttr(kNumberTypeInt32)
                          .AddOutputAttr(kNumberTypeFloat16),
                        SparseGatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
                        KernelAttr()
                          .AddInputAttr(kNumberTypeFloat32)
                          .AddInputAttr(kNumberTypeInt32)
                          .AddInputAttr(kNumberTypeInt64)
                          .AddOutputAttr(kNumberTypeFloat32),
                        SparseGatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
                        KernelAttr()
                          .AddInputAttr(kNumberTypeFloat16)
                          .AddInputAttr(kNumberTypeInt32)
                          .AddInputAttr(kNumberTypeInt64)
                          .AddOutputAttr(kNumberTypeFloat16),
                        SparseGatherV2FwdGpuKernelMod, half, int, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SPARSE_GATHERV2_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SPARSE_GATHERV2_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cuh"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S, typename G>
|
||||
class SparseGatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
public:
|
||||
SparseGatherV2FwdGpuKernelMod() { ResetResource(); }
|
||||
~SparseGatherV2FwdGpuKernelMod() = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
if (is_dynamic_shape_) {
|
||||
G *axis_device_address = GetDeviceAddress<G>(inputs, 2); // only get this if in dynamic mode
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(&axis_, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync axis_ failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
|
||||
"cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
|
||||
Reshape();
|
||||
}
|
||||
auto input_dim1 = input_shapes_[axis_];
|
||||
|
||||
MS_EXCEPTION_IF_NULL(input_addr);
|
||||
MS_EXCEPTION_IF_NULL(indices_addr);
|
||||
GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], LongToSize(input_dim1),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
InitResource();
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num == kSizeThree) {
|
||||
is_dynamic_shape_ = true;
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
|
||||
} else if (input_num == kSizeTwo) {
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2 or 3, but got " << input_num;
|
||||
}
|
||||
input_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
|
||||
indices_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1);
|
||||
output_shapes_ = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
|
||||
CHECK_SHAPE_NULL(indices_shapes_, kernel_name, "indices") ||
|
||||
CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (!is_dynamic_shape_) {
|
||||
int dims = SizeToInt(input_shapes_.size());
|
||||
axis_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "axis"));
|
||||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
|
||||
<< "), but got " << axis_;
|
||||
}
|
||||
Reshape();
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
void ResetResource() noexcept override {
|
||||
is_dynamic_shape_ = false;
|
||||
input_shapes_.clear();
|
||||
indices_shapes_.clear();
|
||||
output_shapes_.clear();
|
||||
std::fill(dims_, dims_ + kSizeThree, 0);
|
||||
axis_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
size_t size = common::AnfAlgo::TensorSizeInByte<T>(input_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
size = common::AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
if (is_dynamic_shape_) {
|
||||
input_size_list_.push_back(sizeof(G));
|
||||
}
|
||||
size = common::AnfAlgo::TensorSizeInByte<T>(output_shapes_);
|
||||
output_size_list_.push_back(size);
|
||||
}
|
||||
|
||||
private:
|
||||
void Reshape() {
|
||||
if (axis_ < 0) {
|
||||
axis_ = axis_ + SizeToInt(input_shapes_.size());
|
||||
}
|
||||
int64_t dim_before_axis = 1;
|
||||
for (size_t i = 0; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) {
|
||||
dim_before_axis *= output_shapes_[i];
|
||||
}
|
||||
int64_t dim_of_indices = 1;
|
||||
for (size_t i = 0; i < indices_shapes_.size(); i++) {
|
||||
dim_of_indices *= indices_shapes_[i];
|
||||
}
|
||||
int64_t dim_after_indices = 1;
|
||||
for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
|
||||
dim_after_indices *= output_shapes_[i];
|
||||
}
|
||||
dims_[kIndex0] = dim_before_axis;
|
||||
dims_[kIndex1] = dim_of_indices;
|
||||
dims_[kIndex2] = dim_after_indices;
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_shapes_;
|
||||
std::vector<int64_t> indices_shapes_;
|
||||
std::vector<int64_t> output_shapes_;
|
||||
int64_t dims_[kIndex3] = {};
|
||||
G axis_;
|
||||
bool is_dynamic_shape_;
|
||||
bool is_null_input_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SPARSE_GATHERV2_GPU_KERNEL_H_
|
Loading…
Reference in New Issue