!41395 gather_ops_support_dynamic_shape

Merge pull request !41395 from yao_yf/gather_ops_support_dynamic_shape
This commit is contained in:
i-robot 2022-09-05 02:17:06 +00:00 committed by Gitee
commit adf8ec5fb8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 531 additions and 288 deletions

View File

@ -32,41 +32,57 @@ constexpr size_t kGatherInputsNum = 2;
constexpr size_t kGatherOutputsNum = 1;
constexpr size_t kGatherInputParamsMaxDim = 7;
} // namespace
void GatherCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
bool GatherCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
size_t input_num = inputs.size();
if (input_num == kGatherInputsNum + 1) {
is_dynamic_shape_ = true;
MS_LOG(DEBUG) << " GatherCPUKernel running in Dynamic Mode.";
} else if (input_num == kGatherInputsNum) {
axis_ = GetValue<int64_t>(base_operator->GetAttr("axis"));
MS_LOG(DEBUG) << " GatherCPUKernel running in Normal Mode.";
} else {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherCPUKernel needs 2.";
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
indices_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
if (is_dynamic_shape_) {
axis_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
}
return true;
}
void GatherCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
CheckParam(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
indices_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
int GatherCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
ResetResource();
input_shape_ = inputs[kIndexZero]->GetShapeVector();
indices_shape_ = inputs[kIndexOne]->GetShapeVector();
output_shape_ = outputs[kIndexZero]->GetShapeVector();
is_null_input_ = input_shape_.empty() || indices_shape_.empty() || output_shape_.empty();
if (is_null_input_) {
InitSizeLists();
return KRET_OK;
}
if (input_shape_.size() > kGatherInputParamsMaxDim) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_params' should be "
<< kGatherInputParamsMaxDim << "D or lower, but got " << input_shape_.size() << ".";
}
if (!is_dynamic_shape_) {
axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
auto it_x = std::find_if(input_shape_.begin(), input_shape_.end(), [](int64_t sh) { return sh <= 0; });
auto it_y = std::find_if(indices_shape_.begin(), indices_shape_.end(), [](int64_t sh) { return sh <= 0; });
if (it_x != input_shape_.end() || it_y != indices_shape_.end()) {
return KRET_UNKNOWN_SHAPE;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Gather does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
node_wpt_ = kernel_node;
InitSizeLists();
return KRET_OK;
}
template <typename T>
@ -76,19 +92,9 @@ bool GatherCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
const auto *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
const auto *indices_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
auto *output_addr = reinterpret_cast<int8_t *>(outputs[0]->addr);
if (!node_wpt_.expired()) {
auto node = node_wpt_.lock();
if (!node) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
}
if (inputs.size() == kGatherInputsNum) {
axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(node, AXIS);
} else if (inputs.size() == kGatherInputsNum + 1) {
if (inputs.size() == kGatherInputsNum + 1) {
axis_ = reinterpret_cast<int64_t *>(inputs[kIndex2]->addr)[0];
} else {
MS_LOG(EXCEPTION) << "Gather requires " << kGatherInputsNum << " or " << (kGatherInputsNum + 1)
<< " inputs, but got " << inputs.size();
}
}
int dims = SizeToInt(input_shape_.size());

View File

@ -19,6 +19,7 @@
#include <utility>
#include <vector>
#include <map>
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
@ -26,14 +27,17 @@
namespace mindspore {
namespace kernel {
class GatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class GatherCpuKernelMod : public NativeCpuKernelMod {
public:
GatherCpuKernelMod() = default;
~GatherCpuKernelMod() override = default;
void CheckParam(const CNodePtr &kernel_node);
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
void InitKernel(const CNodePtr &kernel_node) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
@ -42,6 +46,28 @@ class GatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override;
void ResetResource() noexcept {
input_shape_.clear();
indices_shape_.clear();
output_shape_.clear();
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void InitSizeLists() {
auto input_size = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies{});
auto indices_size = std::accumulate(indices_shape_.begin(), indices_shape_.end(), 1, std::multiplies{});
input_size_list_.push_back(LongToSize(input_size) * input_type_size_);
input_size_list_.push_back(LongToSize(indices_size) * indices_type_size_);
if (is_dynamic_shape_) {
input_size_list_.push_back(axis_type_size_);
}
auto output_size = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies{});
output_size_list_.push_back(LongToSize(output_size) * input_type_size_);
}
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
@ -55,7 +81,10 @@ class GatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
ShapeVector output_shape_;
int64_t axis_{0};
bool is_dynamic_shape_{false};
CNodeWeakPtr node_wpt_;
size_t input_type_size_ = 0;
size_t indices_type_size_ = 0;
size_t axis_type_size_ = 0;
bool is_null_input_ = false;
};
} // namespace kernel
} // namespace mindspore

View File

@ -19,178 +19,222 @@
namespace mindspore {
namespace kernel {
const size_t kStaticInputNum = 2;
const size_t kDynInputNum = 3;
constexpr char GATHER[] = "Gather";
constexpr char GATHERV2[] = "GatherV2";
constexpr char SPARSEGATHERV2[] = "SPARSEGATHERV2";
template <typename T>
using Complex = mindspore::utils::Complex<T>;
bool GatherV2FwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
size_t input_num = inputs.size();
if (input_num == kDynInputNum) {
is_dynamic_shape_ = true;
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
} else if (input_num == kStaticInputNum) {
axis_ = static_cast<int>(GetValue<int64_t>(base_operator->GetAttr("axis")));
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
}
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
GatherV2FwdGpuKernelMod, Complex<float>, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
GatherV2FwdGpuKernelMod, Complex<float>, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
GatherV2FwdGpuKernelMod, Complex<double>, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
GatherV2FwdGpuKernelMod, Complex<double>, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherV2FwdGpuKernelMod, int, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
GatherV2FwdGpuKernelMod, int, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherV2FwdGpuKernelMod, int16_t, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
GatherV2FwdGpuKernelMod, int16_t, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
GatherV2FwdGpuKernelMod, int8_t, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
GatherV2FwdGpuKernelMod, int8_t, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherV2FwdGpuKernelMod, uint, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherV2FwdGpuKernelMod, uint, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherV2FwdGpuKernelMod, uint8_t, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
GatherV2FwdGpuKernelMod, uint8_t, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
// dynamic shape
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
indices_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
if (is_dynamic_shape_) {
axis_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
}
return true;
}
int GatherV2FwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
ResetResource();
input_shapes_ = inputs[kIndexZero]->GetShapeVector();
indices_shapes_ = inputs[kIndexOne]->GetShapeVector();
output_shapes_ = outputs[kIndexZero]->GetShapeVector();
auto it_x = std::find_if(input_shapes_.begin(), input_shapes_.end(), [](int64_t sh) { return sh <= 0; });
auto it_y = std::find_if(indices_shapes_.begin(), indices_shapes_.end(), [](int64_t sh) { return sh <= 0; });
if (it_x != input_shapes_.end() || it_y != indices_shapes_.end()) {
return KRET_UNKNOWN_SHAPE;
}
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name_, "input") ||
CHECK_SHAPE_NULL(indices_shapes_, kernel_name_, "indices") ||
CHECK_SHAPE_NULL(output_shapes_, kernel_name_, "output");
if (is_null_input_) {
InitSizeLists();
return KRET_OK;
}
if (!is_dynamic_shape_) {
int dims = SizeToInt(input_shapes_.size());
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
}
Reshape();
}
InitSizeLists();
return KRET_OK;
}
std::vector<KernelAttr> GatherV2FwdGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const auto &pair) { return pair.first; });
return support_list;
}
template <typename T, typename S, typename G>
bool GatherV2FwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
S *indices_addr = GetDeviceAddress<S>(inputs, kIndex1);
T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
if (is_dynamic_shape_) {
G *axis_device_address = GetDeviceAddress<G>(inputs, kIndex2); // only get this if in dynamic mode
G axis = 0;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&axis, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpy seq_lengths from device to host failed.");
axis_ = static_cast<int>(axis);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaDeviceSynchronize(), "cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
Reshape();
}
auto input_dim1 = input_shapes_[IntToSize(axis_)];
MS_EXCEPTION_IF_NULL(input_addr);
MS_EXCEPTION_IF_NULL(indices_addr);
GatherV2(input_addr, indices_addr, output_addr, dims_[kIndex0], dims_[kIndex1], dims_[kIndex2],
LongToSize(input_dim1), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
std::vector<std::pair<KernelAttr, GatherV2FwdGpuKernelMod::GatherV2Func>> GatherV2FwdGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<float>, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<float>, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<double>, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
&GatherV2FwdGpuKernelMod::LaunchKernel<Complex<double>, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&GatherV2FwdGpuKernelMod::LaunchKernel<int, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
&GatherV2FwdGpuKernelMod::LaunchKernel<int, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
&GatherV2FwdGpuKernelMod::LaunchKernel<int16_t, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
&GatherV2FwdGpuKernelMod::LaunchKernel<int16_t, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
&GatherV2FwdGpuKernelMod::LaunchKernel<int8_t, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
&GatherV2FwdGpuKernelMod::LaunchKernel<int8_t, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
&GatherV2FwdGpuKernelMod::LaunchKernel<uint, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
&GatherV2FwdGpuKernelMod::LaunchKernel<uint, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
&GatherV2FwdGpuKernelMod::LaunchKernel<uint8_t, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
&GatherV2FwdGpuKernelMod::LaunchKernel<uint8_t, int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int64_t, int64_t>},
// dynamic shape
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<double, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<float, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<half, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
&GatherV2FwdGpuKernelMod::LaunchKernel<bool, int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
GatherV2FwdGpuKernelMod, int, int, int64_t)
// dynamic shape ends
MS_REG_GPU_KERNEL_THREE(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int, int64_t)
&GatherV2FwdGpuKernelMod::LaunchKernel<int, int, int64_t>},
};
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Gather, GatherV2FwdGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
#include <algorithm>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
@ -27,81 +29,27 @@
namespace mindspore {
namespace kernel {
template <typename T, typename S, typename G>
class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class GatherV2FwdGpuKernelMod : public NativeGpuKernelMod {
public:
GatherV2FwdGpuKernelMod() { ResetResource(); }
~GatherV2FwdGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (is_dynamic_shape_) {
G *axis_device_address = GetDeviceAddress<G>(inputs, 2); // only get this if in dynamic mode
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&axis_, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync axis_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
"cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
Reshape();
}
auto input_dim1 = input_shapes_[axis_];
MS_EXCEPTION_IF_NULL(input_addr);
MS_EXCEPTION_IF_NULL(indices_addr);
GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], LongToSize(input_dim1),
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
InitResource();
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == 3) {
is_dynamic_shape_ = true;
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
} else if (input_num == 2) {
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2 or 3, but got " << input_num;
}
input_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
indices_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
CHECK_SHAPE_NULL(indices_shapes_, kernel_name, "indices") ||
CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (!is_dynamic_shape_) {
int dims = SizeToInt(input_shapes_.size());
axis_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
}
Reshape();
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
is_dynamic_shape_ = false;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
void ResetResource() noexcept {
input_shapes_.clear();
indices_shapes_.clear();
output_shapes_.clear();
std::fill(dims_, dims_ + 3, 0);
axis_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
@ -109,19 +57,23 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
protected:
void InitSizeLists() override {
size_t size = common::AnfAlgo::TensorSizeInByte<T>(input_shapes_);
input_size_list_.push_back(size);
size = common::AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
input_size_list_.push_back(size);
std::vector<KernelAttr> GetOpSupport() override;
void InitSizeLists() {
auto input_size = std::accumulate(input_shapes_.begin(), input_shapes_.end(), 1, std::multiplies{});
auto indices_size = std::accumulate(indices_shapes_.begin(), indices_shapes_.end(), 1, std::multiplies{});
input_size_list_.push_back(LongToSize(input_size) * input_type_size_);
input_size_list_.push_back(LongToSize(indices_size) * indices_type_size_);
if (is_dynamic_shape_) {
input_size_list_.push_back(sizeof(G));
input_size_list_.push_back(axis_type_size_);
}
size = common::AnfAlgo::TensorSizeInByte<T>(output_shapes_);
output_size_list_.push_back(size);
auto output_size = std::accumulate(output_shapes_.begin(), output_shapes_.end(), 1, std::multiplies{});
output_size_list_.push_back(LongToSize(output_size) * input_type_size_);
}
private:
template <typename T, typename S, typename G>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
void Reshape() {
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_shapes_.size());
@ -138,19 +90,29 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
dim_after_indices *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_of_indices;
dims_[2] = dim_after_indices;
dims_[kIndex0] = dim_before_axis;
dims_[kIndex1] = dim_of_indices;
dims_[kIndex2] = dim_after_indices;
return;
}
cudaStream_t cuda_stream_;
std::vector<int64_t> input_shapes_;
std::vector<int64_t> indices_shapes_;
std::vector<int64_t> output_shapes_;
int64_t dims_[3] = {};
G axis_;
bool is_dynamic_shape_;
bool is_null_input_;
int64_t dims_[kIndex3] = {};
int axis_ = 0;
bool is_dynamic_shape_ = false;
bool is_null_input_ = false;
size_t input_type_size_ = 0;
size_t indices_type_size_ = 0;
size_t axis_type_size_ = 0;
private:
using GatherV2Func = std::function<bool(GatherV2FwdGpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, GatherV2Func>> func_list_;
// static std::vector<std::pair<KernelAttr, GatherV2Func>> sparse_gather_func_list_;
GatherV2Func kernel_func_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/sparse_gatherv2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_THREE(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
SparseGatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
SparseGatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
SparseGatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
SparseGatherV2FwdGpuKernelMod, half, int, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,158 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SPARSE_GATHERV2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SPARSE_GATHERV2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gatherv2.cuh"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S, typename G>
class SparseGatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
public:
SparseGatherV2FwdGpuKernelMod() { ResetResource(); }
~SparseGatherV2FwdGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (is_dynamic_shape_) {
G *axis_device_address = GetDeviceAddress<G>(inputs, 2); // only get this if in dynamic mode
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&axis_, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync axis_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
"cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
Reshape();
}
auto input_dim1 = input_shapes_[axis_];
MS_EXCEPTION_IF_NULL(input_addr);
MS_EXCEPTION_IF_NULL(indices_addr);
GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], LongToSize(input_dim1),
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
InitResource();
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == kSizeThree) {
is_dynamic_shape_ = true;
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
} else if (input_num == kSizeTwo) {
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2 or 3, but got " << input_num;
}
input_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
indices_shapes_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shapes_, kernel_name, "input") ||
CHECK_SHAPE_NULL(indices_shapes_, kernel_name, "indices") ||
CHECK_SHAPE_NULL(output_shapes_, kernel_name, "output");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (!is_dynamic_shape_) {
int dims = SizeToInt(input_shapes_.size());
axis_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
}
Reshape();
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
is_dynamic_shape_ = false;
input_shapes_.clear();
indices_shapes_.clear();
output_shapes_.clear();
std::fill(dims_, dims_ + kSizeThree, 0);
axis_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
size_t size = common::AnfAlgo::TensorSizeInByte<T>(input_shapes_);
input_size_list_.push_back(size);
size = common::AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
input_size_list_.push_back(size);
if (is_dynamic_shape_) {
input_size_list_.push_back(sizeof(G));
}
size = common::AnfAlgo::TensorSizeInByte<T>(output_shapes_);
output_size_list_.push_back(size);
}
private:
void Reshape() {
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_shapes_.size());
}
int64_t dim_before_axis = 1;
for (size_t i = 0; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) {
dim_before_axis *= output_shapes_[i];
}
int64_t dim_of_indices = 1;
for (size_t i = 0; i < indices_shapes_.size(); i++) {
dim_of_indices *= indices_shapes_[i];
}
int64_t dim_after_indices = 1;
for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
dim_after_indices *= output_shapes_[i];
}
dims_[kIndex0] = dim_before_axis;
dims_[kIndex1] = dim_of_indices;
dims_[kIndex2] = dim_after_indices;
return;
}
std::vector<int64_t> input_shapes_;
std::vector<int64_t> indices_shapes_;
std::vector<int64_t> output_shapes_;
int64_t dims_[kIndex3] = {};
G axis_;
bool is_dynamic_shape_;
bool is_null_input_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SPARSE_GATHERV2_GPU_KERNEL_H_