Add vmap feature for IndexFill operator.

hezhenhao1 2022-06-17 10:06:53 +08:00
parent daec6916ca
commit 462f7b03d1
7 changed files with 234 additions and 159 deletions
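For orientation, the user-facing behavior this commit enables looks roughly like the sketch below (not part of the diff; it assumes the Primitive-style `ops.IndexFill` API that the new GPU test at the bottom of this commit also uses):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor, ops
    from mindspore.ops.functional import vmap

    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    index_fill = ops.IndexFill()

    def single(x, dim, index, value):
        # One un-batched call: fill x along axis 'dim' at positions 'index' with scalar 'value'.
        return index_fill(x, dim, index, value)

    x = Tensor(np.arange(24, dtype=np.float32).reshape(2, 3, 4))
    dim = Tensor(np.array([0, 1], dtype=np.int32))              # one 'dim' per batch element
    index = Tensor(np.array([[0, 2], [1, 3]], dtype=np.int32))  # one index vector per batch element
    value = Tensor(np.array([-1.0, -2.0], dtype=np.float32))    # one fill value per batch element
    # vmap batches every argument along axis 0; semantically this equals stacking
    # single(x[i], dim[i], index[i], value[i]) for i in {0, 1}.
    out = vmap(single, in_axes=(0, 0, 0, 0), out_axes=0)(x, dim, index, value)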

View File

@@ -21,7 +21,6 @@ namespace kernel {
 namespace {
 constexpr size_t kIndexFillInputsNum = 4;
 constexpr size_t kIndexFillOutputsNum = 1;
-using kIndexType = int;
 }  // namespace
 bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -39,70 +38,61 @@ bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
   return true;
 }
-void IndexFillGpuKernelMod::ResetResource() noexcept {
-  x_shape_.clear();
-  input_size_list_.clear();
-  output_size_list_.clear();
-  workspace_size_list_.clear();
+void IndexFillGpuKernelMod::UpdateSize(const std::vector<KernelTensorPtr> &inputs,
+                                       const std::vector<KernelTensorPtr> &) {
+  x_shape_ = inputs.at(kIndex0)->GetShapeVector();
+  auto index_shape = inputs.at(kIndex2)->GetShapeVector();
+  int64_t init = 1;
+  x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), init, std::multiplies{});
+  index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), init, std::multiplies{});
 }
 int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                   const std::vector<KernelTensorPtr> &outputs,
                                   const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
-  ResetResource();
-  int ret = KRET_OK;
-  if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
     return ret;
   }
-  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
-  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_);
-  x_shape_ = inputs.at(kIndex0)->GetShapeVector();
-  x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{});
-  auto index_shape = inputs.at(kIndex2)->GetShapeVector();
-  index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{});
+  UpdateSize(inputs, outputs);
   workspace_size_list_.push_back(sizeof(bool));  // Place out_bound.
-  return ret;
+  return KRET_OK;
 }
-template <typename DataType, typename DimType>
+template <typename DataType, typename IndexType>
 bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                          const std::vector<AddressPtr> &workspace,
                                          const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-  if (x_num_ == 0) {
-    return true;
-  }
   auto x_ptr = GetDeviceAddress<DataType>(inputs, kIndex0);
-  MS_EXCEPTION_IF_NULL(x_ptr);
-  auto dim_ptr = GetDeviceAddress<DimType>(inputs, kIndex1);
-  MS_EXCEPTION_IF_NULL(dim_ptr);
-  auto index_ptr = GetDeviceAddress<kIndexType>(inputs, kIndex2);
-  MS_EXCEPTION_IF_NULL(index_ptr);
+  auto dim_ptr = inputs[kIndex1]->addr;
+  auto index_ptr = GetDeviceAddress<IndexType>(inputs, kIndex2);
   auto value_ptr = GetDeviceAddress<DataType>(inputs, kIndex3);
-  MS_EXCEPTION_IF_NULL(value_ptr);
   auto y_ptr = GetDeviceAddress<DataType>(outputs, kIndex0);
-  MS_EXCEPTION_IF_NULL(y_ptr);
   auto out_bound_ptr = GetDeviceAddress<bool>(workspace, kIndex0);
-  MS_EXCEPTION_IF_NULL(out_bound_ptr);
   auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
+  auto any = [](auto &&... args) -> bool { return ((args == nullptr) || ...); };
+  if (any(x_ptr, dim_ptr, index_ptr, value_ptr, y_ptr, out_bound_ptr, cuda_stream)) {
+    return false;
+  }
   // Copy from 'x' into 'y'.
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(y_ptr, x_ptr, x_num_ * sizeof(DataType), cudaMemcpyDeviceToDevice, cuda_stream),
-    "cudaMemcpyAsync output 'y' from 'x' failed.");
+    "In IndexFill kernel, cudaMemcpyAsync output 'y' from 'x' failed.");
   if (index_num_ == 0) {
     return true;
   }
-  // Initialize out_bound_ptr.
-  bool out_bound = false;
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-    cudaMemcpyAsync(out_bound_ptr, &out_bound, sizeof(bool), cudaMemcpyHostToDevice, cuda_stream),
-    "cudaMemcpyAsync out_bound variable failed.");
   // Initialize and check 'dim'.
-  DimType dim, rank;
-  dim = rank = static_cast<DimType>(x_shape_.size());
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim, dim_ptr, sizeof(DimType), cudaMemcpyDeviceToHost),
-                                     "cudaMemcpy input 'dim' device to host failed.");
+  int rank = static_cast<int>(x_shape_.size());
+  int dim;
+  if (inputs[kIndex1]->size == sizeof(int)) {
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim, dim_ptr, inputs[kIndex1]->size, cudaMemcpyDeviceToHost),
+                                       "In IndexFill kernel, cudaMemcpy input 'dim' device to host failed.");
+  } else {
+    int64_t dim_tmp;
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim_tmp, dim_ptr, inputs[kIndex1]->size, cudaMemcpyDeviceToHost),
+                                       "In IndexFill kernel, cudaMemcpy input 'dim' device to host failed.");
+    dim = static_cast<int>(dim_tmp);
+  }
   if (dim < -rank || dim >= rank) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dim' must be in the range [-" << rank << "," << rank
                   << "), but got " << dim;
@@ -110,12 +100,15 @@ bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
   } else if (dim < 0) {
     dim = dim + rank;
   }
+  // Initialize out_bound_ptr.
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(out_bound_ptr, 0, sizeof(bool), cuda_stream),
+                                     "In IndexFill kernel, cudaMemsetAsync out_bound variable failed.");
   // Prepare index_num, dim_size, outer_size, inner_size
   int64_t dim_size = 1;
   int64_t outer_size = 1;
   int64_t inner_size = 1;
   for (size_t i = 0; i < x_shape_.size(); i++) {
-    auto idx = static_cast<DimType>(i);
+    int idx = static_cast<int>(i);
     if (idx < dim) {
       outer_size *= x_shape_.at(i);
     } else if (idx > dim) {
@@ -124,14 +117,15 @@ bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
       dim_size = x_shape_.at(i);
     }
   }
-  int64_t real_index_num = index_num_ * (outer_size * inner_size);
-  IndexFill(y_ptr, index_ptr, real_index_num, outer_size, dim_size, inner_size, value_ptr, out_bound_ptr, cuda_stream);
+  IndexFill(y_ptr, index_ptr, index_num_, outer_size, dim_size, inner_size, value_ptr, out_bound_ptr, device_id_,
+            cuda_stream);
+  bool out_bound = false;
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(&out_bound, out_bound_ptr, sizeof(bool), cudaMemcpyDeviceToHost, cuda_stream),
-    "cudaMemcpyAsync out_bound_ variable failed.");
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), "IndexFill cudaStreamSynchronized failed");
+    "In IndexFill kernel, cudaMemcpyAsync out_bound_ variable failed.");
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream),
+                                     "In IndexFill kernel, cudaStreamSynchronized failed");
   if (out_bound) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', the input 'index' is out of bound.";
     return false;
@@ -217,83 +211,6 @@ std::vector<std::pair<KernelAttr, IndexFillGpuKernelMod::IndexFillLaunchFunc>> I
      .AddInputAttr(kNumberTypeFloat64)
      .AddOutputAttr(kNumberTypeFloat64),
    &IndexFillGpuKernelMod::LaunchKernel<double, int>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt8)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt8)
-     .AddOutputAttr(kNumberTypeUInt8),
-   &IndexFillGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt16)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt16)
-     .AddOutputAttr(kNumberTypeUInt16),
-   &IndexFillGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt32)
-     .AddOutputAttr(kNumberTypeUInt32),
-   &IndexFillGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt64)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt64)
-     .AddOutputAttr(kNumberTypeUInt64),
-   &IndexFillGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt8)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt8)
-     .AddOutputAttr(kNumberTypeInt8),
-   &IndexFillGpuKernelMod::LaunchKernel<int8_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt16)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt16)
-     .AddOutputAttr(kNumberTypeInt16),
-   &IndexFillGpuKernelMod::LaunchKernel<int16_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeInt32),
-   &IndexFillGpuKernelMod::LaunchKernel<int32_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddOutputAttr(kNumberTypeInt64),
-   &IndexFillGpuKernelMod::LaunchKernel<int64_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat16)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeFloat16)
-     .AddOutputAttr(kNumberTypeFloat16),
-   &IndexFillGpuKernelMod::LaunchKernel<half, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddOutputAttr(kNumberTypeFloat32),
-   &IndexFillGpuKernelMod::LaunchKernel<float, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat64)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeFloat64)
-     .AddOutputAttr(kNumberTypeFloat64),
-   &IndexFillGpuKernelMod::LaunchKernel<double, int64_t>},
 };
 std::vector<KernelAttr> IndexFillGpuKernelMod::GetOpSupport() {
@@ -304,6 +221,72 @@ std::vector<KernelAttr> IndexFillGpuKernelMod::GetOpSupport() {
   return support_list;
 }
-MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IndexFill, IndexFillGpuKernelMod);
+class IndexFillVmapGpuKernelMod : public IndexFillGpuKernelMod {
+ public:
+  IndexFillVmapGpuKernelMod() = default;
+  ~IndexFillVmapGpuKernelMod() override = default;
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs,
+             const std::map<uint32_t, tensor::TensorPtr> &others) override {
+    batch_rank_ = base_operator->get_batch_rank();
+    if (batch_rank_ <= 0) {
+      return IndexFillGpuKernelMod::Resize(base_operator, inputs, outputs, others);
+    } else {
+      auto input_shape = inputs.at(kIndex0)->GetShapeVector();
+      batch_size_ = std::accumulate(input_shape.begin(), input_shape.begin() + batch_rank_,
+                                    decltype(input_shape)::value_type(1), std::multiplies{});
+      int ret = IndexFillGpuKernelMod::Resize(base_operator, inputs, outputs, others);
+      auto new_inputs = inputs;
+      for (auto &input : new_inputs) {
+        auto shape = input->GetShapeVector();
+        std::vector<int64_t> new_shape(shape.begin() + batch_rank_, shape.end());
+        input->SetShapeVector(new_shape);
+      }
+      auto new_outputs = outputs;
+      for (auto &output : new_outputs) {
+        auto shape = output->GetShapeVector();
+        std::vector<int64_t> new_shape(shape.begin() + batch_rank_, shape.end());
+        output->SetShapeVector(new_shape);
+      }
+      UpdateSize(new_inputs, new_outputs);
+      return ret;
+    }
+  }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (batch_rank_ <= 0) {
+      return IndexFillGpuKernelMod::Launch(inputs, workspace, outputs, stream_ptr);
+    } else {
+      // Initialize address list of inputs and outputs.
+      std::vector<AddressPtr> new_inputs;
+      std::vector<AddressPtr> new_outputs;
+      (void)std::transform(
+        inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
+        [batch_size = batch_size_](auto ptr) { return std::make_shared<Address>(ptr->addr, ptr->size / batch_size); });
+      (void)std::transform(
+        outputs.begin(), outputs.end(), std::back_inserter(new_outputs),
+        [batch_size = batch_size_](auto ptr) { return std::make_shared<Address>(ptr->addr, ptr->size / batch_size); });
+      for (int64_t i = 0; i < batch_size_; i++) {
+        if (!IndexFillGpuKernelMod::Launch(new_inputs, workspace, new_outputs, stream_ptr)) {
+          return false;
+        }
+        (void)std::for_each(new_inputs.begin(), new_inputs.end(), [](auto &ptr) {
+          ptr->addr = reinterpret_cast<void *>(reinterpret_cast<char *>(ptr->addr) + (ptr->size));
+        });
+        (void)std::for_each(new_outputs.begin(), new_outputs.end(), [](auto &ptr) {
+          ptr->addr = reinterpret_cast<void *>(reinterpret_cast<char *>(ptr->addr) + (ptr->size));
+        });
+      }
+      return true;
+    }
+  }
+ private:
+  int64_t batch_rank_{0};
+  int64_t batch_size_{1};
+};
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IndexFill, IndexFillVmapGpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
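A note on the new `IndexFillVmapGpuKernelMod` above: it reuses the non-batched kernel by splitting every flat input/output buffer into `batch_size_` equal slices and calling the base `Launch` once per slice, advancing each address by one slice per iteration (the workspace flag is shared across slices). A rough host-side model of that loop in Python, purely illustrative (none of these names are MindSpore APIs):

    def batched_launch(inputs, outputs, batch_size, launch_one):
        """Model of the per-batch launch loop: each buffer holds batch_size equal slices."""
        in_step = [len(buf) // batch_size for buf in inputs]
        out_step = [len(buf) // batch_size for buf in outputs]
        for i in range(batch_size):
            in_views = [memoryview(buf)[i * s:(i + 1) * s] for buf, s in zip(inputs, in_step)]
            out_views = [memoryview(buf)[i * s:(i + 1) * s] for buf, s in zip(outputs, out_step)]
            if not launch_one(in_views, out_views):  # one non-batched kernel launch per slice
                return False
        return True

The trade-off is batch_size separate kernel launches instead of one fused launch, which keeps the base kernel untouched.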

View File

@@ -21,6 +21,7 @@
 #include <functional>
 #include <map>
 #include <utility>
+#include <memory>
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/index_fill_impl.cuh"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
@@ -46,11 +47,10 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
              const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
  protected:
+  void UpdateSize(const std::vector<KernelTensorPtr> &, const std::vector<KernelTensorPtr> &);
   std::vector<KernelAttr> GetOpSupport() override;
  private:
-  void ResetResource() noexcept;
   template <typename DataType, typename IndexType>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs, void *stream_ptr);
@@ -60,8 +60,8 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
                                                 const std::vector<AddressPtr> &, void *)>;
   static std::vector<std::pair<KernelAttr, IndexFillLaunchFunc>> func_list_;
   IndexFillLaunchFunc kernel_func_;
-  int64_t x_num_;
-  int64_t index_num_;
+  int64_t x_num_{0};
+  int64_t index_num_{0};
   std::vector<int64_t> x_shape_{};
 };
 }  // namespace kernel

View File

@@ -22,7 +22,7 @@ __global__ void IndexFillKernel(DataType *out_ptr, const int *index_ptr, int64_t
                                 int64_t stride2) {
   for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < index_size; tid += blockDim.x * gridDim.x) {
     // Each index must be [-dim_size, dim_size)
-    int index = index_ptr[tid / stride1];
+    int64_t index = static_cast<int64_t>(index_ptr[tid / stride1]);
     if (index < -dim_size || index >= dim_size) {
       *out_bound_ptr = true;
       break;
@@ -39,50 +39,56 @@ __global__ void IndexFillKernel(DataType *out_ptr, const int *index_ptr, int64_t
 template <typename DataType>
 void IndexFill(DataType *out_ptr, const int *index_ptr, int64_t index_size, int64_t outer_size, int64_t dim_size,
-               int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream) {
+               int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+               cudaStream_t cuda_stream) {
   int64_t stride1 = outer_size * inner_size;
   int64_t stride2 = dim_size * inner_size;
-  IndexFillKernel<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(
-    out_ptr, index_ptr, index_size, dim_size, inner_size, value_ptr, out_bound_ptr, stride1, stride2);
+  int64_t real_index_size = index_size * (outer_size * inner_size);
+  IndexFillKernel<<<CUDA_BLOCKS(device_id, real_index_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    out_ptr, index_ptr, real_index_size, dim_size, inner_size, value_ptr, out_bound_ptr, stride1, stride2);
 }
 template CUDA_LIB_EXPORT void IndexFill<uint8_t>(uint8_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const uint8_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<uint16_t>(uint16_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                   int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                   const uint16_t *value_ptr, bool *out_bound_ptr,
-                                                  cudaStream_t cuda_stream);
+                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<uint32_t>(uint32_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                   int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                   const uint32_t *value_ptr, bool *out_bound_ptr,
-                                                  cudaStream_t cuda_stream);
+                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<uint64_t>(uint64_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                   int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                   const uint64_t *value_ptr, bool *out_bound_ptr,
-                                                  cudaStream_t cuda_stream);
+                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int8_t>(int8_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                                const int8_t *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                                const int8_t *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                                cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int16_t>(int16_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const int16_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int32_t>(int32_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const int32_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int64_t>(int64_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const int64_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<half>(half *out_ptr, const int *index_ptr, int64_t index_size,
                                               int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                              const half *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                              const half *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                              cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<float>(float *out_ptr, const int *index_ptr, int64_t index_size,
                                                int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                               const float *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                               const float *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                               cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<double>(double *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                                const double *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                                const double *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                                cudaStream_t cuda_stream);
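For reference, the stride bookkeeping in the launcher above: viewing `x` as (outer_size, dim_size, inner_size) in row-major order, `stride2 = dim_size * inner_size` is the flat distance between consecutive outer slices, `stride1 = outer_size * inner_size` is the number of threads assigned to each index entry, and the grid now covers `real_index_size = index_size * outer_size * inner_size` elements. A small NumPy check of the flat-offset arithmetic (illustrative only; the kernel body itself lies outside the shown hunks):

    import numpy as np

    outer_size, dim_size, inner_size = 2, 3, 4
    x = np.arange(outer_size * dim_size * inner_size).reshape(outer_size, dim_size, inner_size)
    stride2 = dim_size * inner_size                # flat distance between outer slices
    o, d, i = 1, 2, 3                              # an (outer, dim, inner) coordinate
    assert x.reshape(-1)[o * stride2 + d * inner_size + i] == x[o, d, i]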

View File

@@ -22,6 +22,6 @@
 template <typename DataType>
 CUDA_LIB_EXPORT void IndexFill(DataType *out_ptr, const int *index_ptr, int64_t index_size, int64_t outer_size,
                                int64_t dim_size, int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr,
-                               cudaStream_t cuda_stream);
+                               const uint32_t &device_id, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INDEX_Fill_IMPL_CUH_

View File

@@ -58,24 +58,30 @@ TypePtr IndexFillInferType(const PrimitivePtr &primitive, const std::vector<Abst
 abstract::ShapePtr IndexFillInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
+  size_t batch_rank = 0;
+  if (primitive->HasAttr(kBatchRank)) {
+    auto value_ptr = primitive->GetAttr(kBatchRank);
+    batch_rank = GetValue<int64_t>(value_ptr);
+  }
   // Input 'dim' must be a tensor with a value or a scalar.
   if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
     auto dim_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
     auto dim_rank = SizeToLong(dim_shape.size());
-    (void)CheckAndConvertUtils::CheckInteger("rank of 'dim'", dim_rank, kEqual, 0, prim_name);
+    (void)CheckAndConvertUtils::CheckInteger("rank of 'dim'", dim_rank, kEqual, batch_rank, prim_name);
   }
   // Input 'index' must be a scalar/vector.
   auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
   auto index_rank = SizeToLong(index_shape.size());
-  (void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {0, 1}, prim_name);
+  (void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {batch_rank, batch_rank + 1},
+                                           prim_name);
   // Input 'value' must be a tensor with a value or a scalar.
   if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>()) {
     auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
     auto value_rank = SizeToLong(value_shape.size());
-    (void)CheckAndConvertUtils::CheckInteger("rank of 'value'", value_rank, kEqual, 0, prim_name);
+    (void)CheckAndConvertUtils::CheckInteger("rank of 'value'", value_rank, kEqual, batch_rank, prim_name);
   }
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
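Concretely, with one vmap level (`batch_rank = 1`) the relaxed checks above accept batched scalars: 'dim' and 'value' become 1-D tensors of length B, and 'index' may be 1-D or 2-D. A quick shape illustration (NumPy only, illustrative):

    import numpy as np

    B, N = 2, 4
    x = np.zeros((B, 5, 5), dtype=np.float32)   # batch axes are prepended to 'x' as well
    dim = np.zeros((B,), dtype=np.int32)        # rank of 'dim'   == batch_rank
    index = np.zeros((B, N), dtype=np.int32)    # rank of 'index' in [batch_rank, batch_rank + 1]
    value = np.zeros((B,), dtype=np.float32)    # rank of 'value' == batch_rank
    assert dim.ndim == 1 and value.ndim == 1 and index.ndim in (1, 2)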

View File

@@ -31,6 +31,7 @@ from ..operations.array_ops import Fills
 from ..operations.array_ops import UniqueConsecutive
 from ..operations.array_ops import Col2Im
 from ..operations.array_ops import NonZero
+from ..operations.array_ops import IndexFill
 @vmap_rules_getters.register("Cast")
@@ -1098,6 +1099,38 @@ def get_gather_vmap_rule(prim, axis_size):
     return vmap_rule
+@vmap_rules_getters.register(IndexFill)
+def get_index_fill_rule(prim, axis_size):
+    """VmapRule for `IndexFill` operation."""
+    if hasattr(prim, 'batch_rank'):
+        batch_rank = prim.batch_rank + 1
+    else:
+        batch_rank = 1
+    batch_prim = IndexFill()
+    batch_prim.add_prim_attr('batch_rank', batch_rank)
+    def vmap_rule(x_bdim, dim_bdim, index_bdim, value_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim, value_bdim)
+        if is_all_none:
+            return result
+        x, x_dim = x_bdim
+        dim, dim_dim = dim_bdim
+        index, index_dim = index_bdim
+        value, value_dim = value_bdim
+        x = _bdim_at_front(x, x_dim, axis_size)
+        dim = _bdim_at_front(dim, dim_dim, axis_size)
+        index = _bdim_at_front(index, index_dim, axis_size)
+        value = _bdim_at_front(value, value_dim, axis_size)
+        out = batch_prim(x, dim, index, value)
+        return out, 0
+    return vmap_rule
 @vmap_rules_getters.register(P.DataFormatDimMap)
 def get_pdist_vmap_rule(prim, axis_size):
     """VmapRule for `DataFormatDimMap`"""

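The rule above moves every batched argument to axis 0 with `_bdim_at_front` and registers a `batch_rank` attribute on a fresh IndexFill primitive, so nested vmap calls simply accumulate leading batch axes. A hedged sketch of that accumulation (`index_fill_fn` is a placeholder wrapper, not code from this commit):

    from mindspore import ops
    from mindspore.ops.functional import vmap

    index_fill = ops.IndexFill()

    def index_fill_fn(x, dim, index, value):
        return index_fill(x, dim, index, value)

    inner = vmap(index_fill_fn, in_axes=(0, 0, 0, 0))  # rule above runs with batch_rank = 1
    outer = vmap(inner, in_axes=(0, 0, 0, 0))          # rule runs again, batch_rank becomes 2
    # This is why get_index_fill_rule reads prim.batch_rank when present and
    # registers batch_rank + 1 on batch_prim.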
View File

@@ -20,6 +20,7 @@ import mindspore.nn as nn
 from mindspore import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.ops as ops
+from mindspore.ops.functional import vmap
 class IndexFillNet(nn.Cell):
@@ -32,18 +33,7 @@
         return out
-def compare_with_numpy(x, dim, index, value):
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    ms_x = Tensor(x)
-    ms_dim = dim
-    ms_index = Tensor(index)
-    ms_value = value
-    ms_result_graph = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
-    # PyNative Mode
-    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    ms_result_pynative = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
-    # Numpy
+def numpy_index_fill(x, dim, index, value):
     np_result = x.copy()
     if dim == 0:
         np_result[index] = value
@@ -53,6 +43,23 @@ def compare_with_numpy(x, dim, index, value):
         np_result[:, :, index] = value
     else:
         raise ValueError("dim must be 0, 1 or 2")
+    return np_result
+def compare_with_numpy(x, dim, index, value):
+    # Graph Mode
+    ms_x = Tensor(x)
+    ms_dim = dim
+    ms_index = Tensor(index)
+    ms_value = value
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    ms_result_graph = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
+    # PyNative Mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    ms_result_pynative = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
+    # Numpy
+    np_result = numpy_index_fill(x, dim, index, value)
     return np.allclose(ms_result_graph, np_result) and np.allclose(ms_result_pynative, np_result)
@@ -60,7 +67,7 @@ def compare_with_numpy(x, dim, index, value):
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('data_type', [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64])
+@pytest.mark.parametrize('data_type', [np.int8, np.int16, np.int32, np.float16, np.float64])
 def test_index_fill_data_type(data_type):
     """
     Feature: IndexFill
@@ -127,3 +134,43 @@ def test_index_fill_error(dim, data_type):
     with pytest.raises(RuntimeError):
         IndexFillNet()(ms_x, ms_dim, ms_index, ms_value)
+class IndexFillVmapNet(nn.Cell):
+    def __init__(self, net, in_axes):
+        super(IndexFillVmapNet, self).__init__()
+        self.net = net
+        self.vmap_index_fill = vmap(self.net, in_axes=in_axes, out_axes=0)
+    def construct(self, x, dim, index, value):
+        out = self.vmap_index_fill(x, dim, index, value)
+        return out
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_fill_vmap():
+    """
+    Feature: IndexFill
+    Description: test cases of vmap for IndexFill operator.
+    Expectation: the result match numpy.
+    """
+    data_type = np.float32
+    dim_type = np.int32
+    dim = Tensor(np.array([0, 1], dtype=dim_type))
+    value = Tensor(np.array([-20, -10], dtype=data_type))
+    x_np = np.random.random(size=(2, 5, 5)).astype(data_type)
+    index_np = np.random.randint(low=0, high=5, size=(2, 4)).astype(np.int32)
+    # MindSpore
+    ms_x = Tensor(x_np)
+    ms_index = Tensor(index_np)
+    ms_result_graph = IndexFillVmapNet(IndexFillNet(), (0, 0, 0, 0))(ms_x, dim, ms_index, value).asnumpy()
+    # NumPy
+    np_result = [None, None]
+    for i in range(2):
+        np_result[i] = numpy_index_fill(x_np[i], dim[i], index_np[i], value[i])
+    np_result = np.asarray(np_result)
+    assert np.allclose(ms_result_graph, np_result)