forked from mindspore-Ecosystem/mindspore
!36106 Add vmap feature for IndexFill operator.
Merge pull request !36106 from hezhenhao1/add_vmap
commit 85131d20c2
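This change adds vmap (batching) support for the IndexFill GPU operator: a vmap rule is registered for the primitive and a batched kernel variant runs the existing kernel once per element of the leading batch dimension. Below is a minimal usage sketch of the resulting behaviour; it is illustrative rather than part of the commit, and the import path and shapes simply mirror the new test case at the bottom of this diff.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops.functional import vmap
from mindspore.ops.operations.array_ops import IndexFill


class IndexFillNet(nn.Cell):
    def __init__(self):
        super(IndexFillNet, self).__init__()
        self.index_fill = IndexFill()

    def construct(self, x, dim, index, value):
        return self.index_fill(x, dim, index, value)


context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x = Tensor(np.random.random((2, 5, 5)).astype(np.float32))             # batch of two 5x5 inputs
dim = Tensor(np.array([0, 1], dtype=np.int32))                         # one 'dim' per batch element
index = Tensor(np.random.randint(0, 5, size=(2, 4)).astype(np.int32))  # one index vector per batch element
value = Tensor(np.array([-20, -10], dtype=np.float32))                 # one fill value per batch element
# Map IndexFill over axis 0 of every input; each batch element is filled independently.
out = vmap(IndexFillNet(), in_axes=(0, 0, 0, 0), out_axes=0)(x, dim, index, value)
print(out.shape)  # (2, 5, 5)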
@@ -21,7 +21,6 @@ namespace kernel {
namespace {
constexpr size_t kIndexFillInputsNum = 4;
constexpr size_t kIndexFillOutputsNum = 1;
using kIndexType = int;
} // namespace

bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -39,70 +38,61 @@ bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
  return true;
}

void IndexFillGpuKernelMod::ResetResource() noexcept {
  x_shape_.clear();
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
void IndexFillGpuKernelMod::UpdateSize(const std::vector<KernelTensorPtr> &inputs,
                                       const std::vector<KernelTensorPtr> &) {
  x_shape_ = inputs.at(kIndex0)->GetShapeVector();
  auto index_shape = inputs.at(kIndex2)->GetShapeVector();
  int64_t init = 1;
  x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), init, std::multiplies{});
  index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), init, std::multiplies{});
}

int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  ResetResource();
  int ret = KRET_OK;
  if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
    return ret;
  }
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_);
  x_shape_ = inputs.at(kIndex0)->GetShapeVector();
  x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{});

  auto index_shape = inputs.at(kIndex2)->GetShapeVector();
  index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{});
  UpdateSize(inputs, outputs);
  workspace_size_list_.push_back(sizeof(bool)); // Place out_bound.
  return ret;
  return KRET_OK;
}

template <typename DataType, typename DimType>
template <typename DataType, typename IndexType>
bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                         const std::vector<AddressPtr> &workspace,
                                         const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  if (x_num_ == 0) {
    return true;
  }
  auto x_ptr = GetDeviceAddress<DataType>(inputs, kIndex0);
  MS_EXCEPTION_IF_NULL(x_ptr);
  auto dim_ptr = GetDeviceAddress<DimType>(inputs, kIndex1);
  MS_EXCEPTION_IF_NULL(dim_ptr);
  auto index_ptr = GetDeviceAddress<kIndexType>(inputs, kIndex2);
  MS_EXCEPTION_IF_NULL(index_ptr);
  auto dim_ptr = inputs[kIndex1]->addr;
  auto index_ptr = GetDeviceAddress<IndexType>(inputs, kIndex2);
  auto value_ptr = GetDeviceAddress<DataType>(inputs, kIndex3);
  MS_EXCEPTION_IF_NULL(value_ptr);
  auto y_ptr = GetDeviceAddress<DataType>(outputs, kIndex0);
  MS_EXCEPTION_IF_NULL(y_ptr);
  auto out_bound_ptr = GetDeviceAddress<bool>(workspace, kIndex0);
  MS_EXCEPTION_IF_NULL(out_bound_ptr);
  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
  auto any = [](auto &&... args) -> bool { return ((args == nullptr) || ...); };
  if (any(x_ptr, dim_ptr, index_ptr, value_ptr, y_ptr, out_bound_ptr, cuda_stream)) {
    return false;
  }

  // Copy from 'x' into 'y'.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(y_ptr, x_ptr, x_num_ * sizeof(DataType), cudaMemcpyDeviceToDevice, cuda_stream),
    "cudaMemcpyAsync output 'y' from 'x' failed.");
    "In IndexFill kernel, cudaMemcpyAsync output 'y' from 'x' failed.");
  if (index_num_ == 0) {
    return true;
  }
  // Initialize out_bound_ptr.
  bool out_bound = false;
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(out_bound_ptr, &out_bound, sizeof(bool), cudaMemcpyHostToDevice, cuda_stream),
    "cudaMemcpyAsync out_bound variable failed.");
  // Initialize and check 'dim'.
  DimType dim, rank;
  dim = rank = static_cast<DimType>(x_shape_.size());
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim, dim_ptr, sizeof(DimType), cudaMemcpyDeviceToHost),
                                     "cudaMemcpy input 'dim' device to host failed.");
  int rank = static_cast<int>(x_shape_.size());
  int dim;
  if (inputs[kIndex1]->size == sizeof(int)) {
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim, dim_ptr, inputs[kIndex1]->size, cudaMemcpyDeviceToHost),
                                       "In IndexFill kernel, cudaMemcpy input 'dim' device to host failed.");
  } else {
    int64_t dim_tmp;
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim_tmp, dim_ptr, inputs[kIndex1]->size, cudaMemcpyDeviceToHost),
                                       "In IndexFill kernel, cudaMemcpy input 'dim' device to host failed.");
    dim = static_cast<int>(dim_tmp);
  }
  if (dim < -rank || dim >= rank) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dim' must be in the range [-" << rank << "," << rank
                  << "), but got " << dim;
@@ -110,12 +100,15 @@ bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
  } else if (dim < 0) {
    dim = dim + rank;
  }
  // Initialize out_bound_ptr.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(out_bound_ptr, 0, sizeof(bool), cuda_stream),
                                     "In IndexFill kernel, cudaMemsetAsync out_bound variable failed.");
  // Prepare index_num, dim_size, outer_size, inner_size
  int64_t dim_size = 1;
  int64_t outer_size = 1;
  int64_t inner_size = 1;
  for (size_t i = 0; i < x_shape_.size(); i++) {
    auto idx = static_cast<DimType>(i);
    int idx = static_cast<int>(i);
    if (idx < dim) {
      outer_size *= x_shape_.at(i);
    } else if (idx > dim) {
@@ -124,14 +117,15 @@ bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
      dim_size = x_shape_.at(i);
    }
  }
  IndexFill(y_ptr, index_ptr, index_num_, outer_size, dim_size, inner_size, value_ptr, out_bound_ptr, device_id_,
            cuda_stream);

  int64_t real_index_num = index_num_ * (outer_size * inner_size);
  IndexFill(y_ptr, index_ptr, real_index_num, outer_size, dim_size, inner_size, value_ptr, out_bound_ptr, cuda_stream);

  bool out_bound = false;
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(&out_bound, out_bound_ptr, sizeof(bool), cudaMemcpyDeviceToHost, cuda_stream),
    "cudaMemcpyAsync out_bound_ variable failed.");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), "IndexFill cudaStreamSynchronized failed");
    "In IndexFill kernel, cudaMemcpyAsync out_bound_ variable failed.");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream),
                                     "In IndexFill kernel, cudaStreamSynchronized failed");
  if (out_bound) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the input 'index' is out of bound.";
    return false;
@@ -217,83 +211,6 @@ std::vector<std::pair<KernelAttr, IndexFillGpuKernelMod::IndexFillLaunchFunc>> I
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &IndexFillGpuKernelMod::LaunchKernel<double, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt8)
     .AddOutputAttr(kNumberTypeUInt8),
   &IndexFillGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt16)
     .AddOutputAttr(kNumberTypeUInt16),
   &IndexFillGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   &IndexFillGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   &IndexFillGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt8)
     .AddOutputAttr(kNumberTypeInt8),
   &IndexFillGpuKernelMod::LaunchKernel<int8_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt16)
     .AddOutputAttr(kNumberTypeInt16),
   &IndexFillGpuKernelMod::LaunchKernel<int16_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   &IndexFillGpuKernelMod::LaunchKernel<int32_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   &IndexFillGpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &IndexFillGpuKernelMod::LaunchKernel<half, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &IndexFillGpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &IndexFillGpuKernelMod::LaunchKernel<double, int64_t>},
};

std::vector<KernelAttr> IndexFillGpuKernelMod::GetOpSupport() {
@@ -304,6 +221,72 @@ std::vector<KernelAttr> IndexFillGpuKernelMod::GetOpSupport() {
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IndexFill, IndexFillGpuKernelMod);
class IndexFillVmapGpuKernelMod : public IndexFillGpuKernelMod {
 public:
  IndexFillVmapGpuKernelMod() = default;
  ~IndexFillVmapGpuKernelMod() override = default;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &others) override {
    batch_rank_ = base_operator->get_batch_rank();
    if (batch_rank_ <= 0) {
      return IndexFillGpuKernelMod::Resize(base_operator, inputs, outputs, others);
    } else {
      auto input_shape = inputs.at(kIndex0)->GetShapeVector();
      batch_size_ = std::accumulate(input_shape.begin(), input_shape.begin() + batch_rank_,
                                    decltype(input_shape)::value_type(1), std::multiplies{});
      int ret = IndexFillGpuKernelMod::Resize(base_operator, inputs, outputs, others);
      auto new_inputs = inputs;
      for (auto &input : new_inputs) {
        auto shape = input->GetShapeVector();
        std::vector<int64_t> new_shape(shape.begin() + batch_rank_, shape.end());
        input->SetShapeVector(new_shape);
      }
      auto new_outputs = outputs;
      for (auto &output : new_outputs) {
        auto shape = output->GetShapeVector();
        std::vector<int64_t> new_shape(shape.begin() + batch_rank_, shape.end());
        output->SetShapeVector(new_shape);
      }
      UpdateSize(new_inputs, new_outputs);
      return ret;
    }
  }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (batch_rank_ <= 0) {
      return IndexFillGpuKernelMod::Launch(inputs, workspace, outputs, stream_ptr);
    } else {
      // Initialize address list of inputs and outputs.
      std::vector<AddressPtr> new_inputs;
      std::vector<AddressPtr> new_outputs;
      (void)std::transform(
        inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
        [batch_size = batch_size_](auto ptr) { return std::make_shared<Address>(ptr->addr, ptr->size / batch_size); });
      (void)std::transform(
        outputs.begin(), outputs.end(), std::back_inserter(new_outputs),
        [batch_size = batch_size_](auto ptr) { return std::make_shared<Address>(ptr->addr, ptr->size / batch_size); });
      for (int64_t i = 0; i < batch_size_; i++) {
        if (!IndexFillGpuKernelMod::Launch(new_inputs, workspace, new_outputs, stream_ptr)) {
          return false;
        }
        (void)std::for_each(new_inputs.begin(), new_inputs.end(), [](auto &ptr) {
          ptr->addr = reinterpret_cast<void *>(reinterpret_cast<char *>(ptr->addr) + (ptr->size));
        });
        (void)std::for_each(new_outputs.begin(), new_outputs.end(), [](auto &ptr) {
          ptr->addr = reinterpret_cast<void *>(reinterpret_cast<char *>(ptr->addr) + (ptr->size));
        });
      }
      return true;
    }
  }

 private:
  int64_t batch_rank_{0};
  int64_t batch_size_{1};
};
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IndexFill, IndexFillVmapGpuKernelMod);
} // namespace kernel
} // namespace mindspore

@@ -21,6 +21,7 @@
#include <functional>
#include <map>
#include <utility>
#include <memory>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/index_fill_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"

@@ -46,11 +47,10 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  void UpdateSize(const std::vector<KernelTensorPtr> &, const std::vector<KernelTensorPtr> &);
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  void ResetResource() noexcept;

  template <typename DataType, typename IndexType>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr);

@@ -60,8 +60,8 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
                                             const std::vector<AddressPtr> &, void *)>;
  static std::vector<std::pair<KernelAttr, IndexFillLaunchFunc>> func_list_;
  IndexFillLaunchFunc kernel_func_;
  int64_t x_num_;
  int64_t index_num_;
  int64_t x_num_{0};
  int64_t index_num_{0};
  std::vector<int64_t> x_shape_{};
};
} // namespace kernel

@@ -22,7 +22,7 @@ __global__ void IndexFillKernel(DataType *out_ptr, const int *index_ptr, int64_t
                                int64_t stride2) {
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < index_size; tid += blockDim.x * gridDim.x) {
    // Each index must be [-dim_size, dim_size)
    int index = index_ptr[tid / stride1];
    int64_t index = static_cast<int64_t>(index_ptr[tid / stride1]);
    if (index < -dim_size || index >= dim_size) {
      *out_bound_ptr = true;
      break;

@@ -39,50 +39,56 @@ __global__ void IndexFillKernel(DataType *out_ptr, const int *index_ptr, int64_t

template <typename DataType>
void IndexFill(DataType *out_ptr, const int *index_ptr, int64_t index_size, int64_t outer_size, int64_t dim_size,
               int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream) {
               int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
               cudaStream_t cuda_stream) {
  int64_t stride1 = outer_size * inner_size;
  int64_t stride2 = dim_size * inner_size;
  IndexFillKernel<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(
    out_ptr, index_ptr, index_size, dim_size, inner_size, value_ptr, out_bound_ptr, stride1, stride2);
  int64_t real_index_size = index_size * (outer_size * inner_size);
  IndexFillKernel<<<CUDA_BLOCKS(device_id, real_index_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    out_ptr, index_ptr, real_index_size, dim_size, inner_size, value_ptr, out_bound_ptr, stride1, stride2);
}

template CUDA_LIB_EXPORT void IndexFill<uint8_t>(uint8_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                 const uint8_t *value_ptr, bool *out_bound_ptr,
                                                 cudaStream_t cuda_stream);
                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<uint16_t>(uint16_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const uint16_t *value_ptr, bool *out_bound_ptr,
                                                  cudaStream_t cuda_stream);
                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<uint32_t>(uint32_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const uint32_t *value_ptr, bool *out_bound_ptr,
                                                  cudaStream_t cuda_stream);
                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<uint64_t>(uint64_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const uint64_t *value_ptr, bool *out_bound_ptr,
                                                  cudaStream_t cuda_stream);
                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<int8_t>(int8_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                const int8_t *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
                                                const int8_t *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
                                                cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<int16_t>(int16_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                 const int16_t *value_ptr, bool *out_bound_ptr,
                                                 cudaStream_t cuda_stream);
                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<int32_t>(int32_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                 const int32_t *value_ptr, bool *out_bound_ptr,
                                                 cudaStream_t cuda_stream);
                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<int64_t>(int64_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                 const int64_t *value_ptr, bool *out_bound_ptr,
                                                 cudaStream_t cuda_stream);
                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<half>(half *out_ptr, const int *index_ptr, int64_t index_size,
                                              int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                              const half *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
                                              const half *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
                                              cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<float>(float *out_ptr, const int *index_ptr, int64_t index_size,
                                               int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                               const float *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
                                               const float *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
                                               cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<double>(double *out_ptr, const int *index_ptr, int64_t index_size,
                                                int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                const double *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
                                                const double *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
                                                cudaStream_t cuda_stream);

@@ -22,6 +22,6 @@
template <typename DataType>
CUDA_LIB_EXPORT void IndexFill(DataType *out_ptr, const int *index_ptr, int64_t index_size, int64_t outer_size,
                               int64_t dim_size, int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr,
                               cudaStream_t cuda_stream);
                               const uint32_t &device_id, cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INDEX_Fill_IMPL_CUH_

@@ -58,24 +58,30 @@ TypePtr IndexFillInferType(const PrimitivePtr &primitive, const std::vector<Abst
abstract::ShapePtr IndexFillInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  size_t batch_rank = 0;
  if (primitive->HasAttr(kBatchRank)) {
    auto value_ptr = primitive->GetAttr(kBatchRank);
    batch_rank = GetValue<int64_t>(value_ptr);
  }

  // Input 'dim' must be a tensor with a value or a scalar.
  if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
    auto dim_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
    auto dim_rank = SizeToLong(dim_shape.size());
    (void)CheckAndConvertUtils::CheckInteger("rank of 'dim'", dim_rank, kEqual, 0, prim_name);
    (void)CheckAndConvertUtils::CheckInteger("rank of 'dim'", dim_rank, kEqual, batch_rank, prim_name);
  }

  // Input 'index' must be a scalar/vector.
  auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto index_rank = SizeToLong(index_shape.size());
  (void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {0, 1}, prim_name);
  (void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {batch_rank, batch_rank + 1},
                                           prim_name);

  // Input 'value' must be a tensor with a value or a scalar.
  if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>()) {
    auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
    auto value_rank = SizeToLong(value_shape.size());
    (void)CheckAndConvertUtils::CheckInteger("rank of 'value'", value_rank, kEqual, 0, prim_name);
    (void)CheckAndConvertUtils::CheckInteger("rank of 'value'", value_rank, kEqual, batch_rank, prim_name);
  }

  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];

@@ -31,6 +31,7 @@ from ..operations.array_ops import Fills
from ..operations.array_ops import UniqueConsecutive
from ..operations.array_ops import Col2Im
from ..operations.array_ops import NonZero
from ..operations.array_ops import IndexFill


@vmap_rules_getters.register("Cast")

@@ -1142,6 +1143,38 @@ def get_gather_vmap_rule(prim, axis_size):
    return vmap_rule


@vmap_rules_getters.register(IndexFill)
def get_index_fill_rule(prim, axis_size):
    """VmapRule for `IndexFill` operation."""
    if hasattr(prim, 'batch_rank'):
        batch_rank = prim.batch_rank + 1
    else:
        batch_rank = 1

    batch_prim = IndexFill()
    batch_prim.add_prim_attr('batch_rank', batch_rank)

    def vmap_rule(x_bdim, dim_bdim, index_bdim, value_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim, value_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        dim, dim_dim = dim_bdim
        index, index_dim = index_bdim
        value, value_dim = value_bdim

        x = _bdim_at_front(x, x_dim, axis_size)
        dim = _bdim_at_front(dim, dim_dim, axis_size)
        index = _bdim_at_front(index, index_dim, axis_size)
        value = _bdim_at_front(value, value_dim, axis_size)

        out = batch_prim(x, dim, index, value)
        return out, 0

    return vmap_rule


@vmap_rules_getters.register(P.DataFormatDimMap)
def get_pdist_vmap_rule(prim, axis_size):
    """VmapRule for `DataFormatDimMap`"""

@@ -20,6 +20,7 @@ import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.ops as ops
from mindspore.ops.functional import vmap


class IndexFillNet(nn.Cell):

@@ -32,18 +33,7 @@ class IndexFillNet(nn.Cell):
        return out


def compare_with_numpy(x, dim, index, value):
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    ms_x = Tensor(x)
    ms_dim = dim
    ms_index = Tensor(index)
    ms_value = value
    ms_result_graph = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
    # PyNative Mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    ms_result_pynative = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()

    # Numpy
def numpy_index_fill(x, dim, index, value):
    np_result = x.copy()
    if dim == 0:
        np_result[index] = value

@@ -53,6 +43,23 @@ def compare_with_numpy(x, dim, index, value):
        np_result[:, :, index] = value
    else:
        raise ValueError("dim must be 0, 1 or 2")
    return np_result


def compare_with_numpy(x, dim, index, value):
    # Graph Mode
    ms_x = Tensor(x)
    ms_dim = dim
    ms_index = Tensor(index)
    ms_value = value
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    ms_result_graph = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
    # PyNative Mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    ms_result_pynative = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()

    # Numpy
    np_result = numpy_index_fill(x, dim, index, value)

    return np.allclose(ms_result_graph, np_result) and np.allclose(ms_result_pynative, np_result)


@@ -60,7 +67,7 @@ def compare_with_numpy(x, dim, index, value):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64])
@pytest.mark.parametrize('data_type', [np.int8, np.int16, np.int32, np.float16, np.float64])
def test_index_fill_data_type(data_type):
    """
    Feature: IndexFill

@@ -127,3 +134,43 @@ def test_index_fill_error(dim, data_type):

    with pytest.raises(RuntimeError):
        IndexFillNet()(ms_x, ms_dim, ms_index, ms_value)


class IndexFillVmapNet(nn.Cell):
    def __init__(self, net, in_axes):
        super(IndexFillVmapNet, self).__init__()
        self.net = net
        self.vmap_index_fill = vmap(self.net, in_axes=in_axes, out_axes=0)

    def construct(self, x, dim, index, value):
        out = self.vmap_index_fill(x, dim, index, value)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_fill_vmap():
    """
    Feature: IndexFill
    Description: test cases of vmap for IndexFill operator.
    Expectation: the result match numpy.
    """
    data_type = np.float32
    dim_type = np.int32
    dim = Tensor(np.array([0, 1], dtype=dim_type))
    value = Tensor(np.array([-20, -10], dtype=data_type))
    x_np = np.random.random(size=(2, 5, 5)).astype(data_type)
    index_np = np.random.randint(low=0, high=5, size=(2, 4)).astype(np.int32)

    # MindSpore
    ms_x = Tensor(x_np)
    ms_index = Tensor(index_np)
    ms_result_graph = IndexFillVmapNet(IndexFillNet(), (0, 0, 0, 0))(ms_x, dim, ms_index, value).asnumpy()

    # NumPy
    np_result = [None, None]
    for i in range(2):
        np_result[i] = numpy_index_fill(x_np[i], dim[i], index_np[i], value[i])
    np_result = np.asarray(np_result)
    assert np.allclose(ms_result_graph, np_result)