Add vmap feature for IndexFill operator.
Commit: 462f7b03d1
Parent: daec6916ca
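For reference, the batched semantics this change provides can be sketched in NumPy (illustrative helper names only, not part of the commit): vmap applies IndexFill independently to every slice along a leading batch axis, which is also what the new GPU test added below checks against.

# Illustrative NumPy reference only; it mirrors the numpy_index_fill helper used in
# the tests below, with an extra leading batch axis for the vmapped case.
import numpy as np

def index_fill_ref(x, dim, index, value):
    # Fill x along axis `dim` at positions `index` with scalar `value`.
    out = x.copy()
    slicer = [slice(None)] * out.ndim
    slicer[dim] = index
    out[tuple(slicer)] = value
    return out

def batched_index_fill_ref(x, dim, index, value):
    # vmap semantics: every argument carries a leading batch axis of the same size.
    return np.stack([index_fill_ref(x[b], int(dim[b]), index[b], value[b])
                     for b in range(x.shape[0])])

x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
dim = np.array([0, 1], dtype=np.int32)
index = np.array([[1], [2]], dtype=np.int32)
value = np.array([-20.0, -10.0], dtype=np.float32)
print(batched_index_fill_ref(x, dim, index, value).shape)  # (2, 3, 4)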
@@ -21,7 +21,6 @@ namespace kernel {
 namespace {
 constexpr size_t kIndexFillInputsNum = 4;
 constexpr size_t kIndexFillOutputsNum = 1;
-using kIndexType = int;
 }  // namespace

 bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -39,70 +38,61 @@ bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
   return true;
 }

-void IndexFillGpuKernelMod::ResetResource() noexcept {
-  x_shape_.clear();
-  input_size_list_.clear();
-  output_size_list_.clear();
-  workspace_size_list_.clear();
+void IndexFillGpuKernelMod::UpdateSize(const std::vector<KernelTensorPtr> &inputs,
+                                       const std::vector<KernelTensorPtr> &) {
+  x_shape_ = inputs.at(kIndex0)->GetShapeVector();
+  auto index_shape = inputs.at(kIndex2)->GetShapeVector();
+  int64_t init = 1;
+  x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), init, std::multiplies{});
+  index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), init, std::multiplies{});
 }

 int IndexFillGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                   const std::vector<KernelTensorPtr> &outputs,
                                   const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
-  ResetResource();
-  int ret = KRET_OK;
-  if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
     return ret;
   }
-  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
-  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexFillOutputsNum, kernel_name_);
-  x_shape_ = inputs.at(kIndex0)->GetShapeVector();
-  x_num_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies{});
-
-  auto index_shape = inputs.at(kIndex2)->GetShapeVector();
-  index_num_ = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies{});
+  UpdateSize(inputs, outputs);
   workspace_size_list_.push_back(sizeof(bool));  // Place out_bound.
-  return ret;
+  return KRET_OK;
 }

-template <typename DataType, typename DimType>
+template <typename DataType, typename IndexType>
 bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                          const std::vector<AddressPtr> &workspace,
                                          const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-  if (x_num_ == 0) {
-    return true;
-  }
   auto x_ptr = GetDeviceAddress<DataType>(inputs, kIndex0);
-  MS_EXCEPTION_IF_NULL(x_ptr);
-  auto dim_ptr = GetDeviceAddress<DimType>(inputs, kIndex1);
-  MS_EXCEPTION_IF_NULL(dim_ptr);
-  auto index_ptr = GetDeviceAddress<kIndexType>(inputs, kIndex2);
-  MS_EXCEPTION_IF_NULL(index_ptr);
+  auto dim_ptr = inputs[kIndex1]->addr;
+  auto index_ptr = GetDeviceAddress<IndexType>(inputs, kIndex2);
   auto value_ptr = GetDeviceAddress<DataType>(inputs, kIndex3);
-  MS_EXCEPTION_IF_NULL(value_ptr);
   auto y_ptr = GetDeviceAddress<DataType>(outputs, kIndex0);
-  MS_EXCEPTION_IF_NULL(y_ptr);
   auto out_bound_ptr = GetDeviceAddress<bool>(workspace, kIndex0);
-  MS_EXCEPTION_IF_NULL(out_bound_ptr);
   auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
+  auto any = [](auto &&... args) -> bool { return ((args == nullptr) || ...); };
+  if (any(x_ptr, dim_ptr, index_ptr, value_ptr, y_ptr, out_bound_ptr, cuda_stream)) {
+    return false;
+  }

   // Copy from 'x' into 'y'.
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(y_ptr, x_ptr, x_num_ * sizeof(DataType), cudaMemcpyDeviceToDevice, cuda_stream),
-    "cudaMemcpyAsync output 'y' from 'x' failed.");
+    "In IndexFill kernel, cudaMemcpyAsync output 'y' from 'x' failed.");
   if (index_num_ == 0) {
     return true;
   }
-  // Initialize out_bound_ptr.
-  bool out_bound = false;
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-    cudaMemcpyAsync(out_bound_ptr, &out_bound, sizeof(bool), cudaMemcpyHostToDevice, cuda_stream),
-    "cudaMemcpyAsync out_bound variable failed.");
   // Initialize and check 'dim'.
-  DimType dim, rank;
-  dim = rank = static_cast<DimType>(x_shape_.size());
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim, dim_ptr, sizeof(DimType), cudaMemcpyDeviceToHost),
-                                     "cudaMemcpy input 'dim' device to host failed.");
+  int rank = static_cast<int>(x_shape_.size());
+  int dim;
+  if (inputs[kIndex1]->size == sizeof(int)) {
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim, dim_ptr, inputs[kIndex1]->size, cudaMemcpyDeviceToHost),
+                                       "In IndexFill kernel, cudaMemcpy input 'dim' device to host failed.");
+  } else {
+    int64_t dim_tmp;
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&dim_tmp, dim_ptr, inputs[kIndex1]->size, cudaMemcpyDeviceToHost),
+                                       "In IndexFill kernel, cudaMemcpy input 'dim' device to host failed.");
+    dim = static_cast<int>(dim_tmp);
+  }
   if (dim < -rank || dim >= rank) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'dim' must be in the range [-" << rank << "," << rank
                   << "), but got " << dim;
@@ -110,12 +100,15 @@ bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
   } else if (dim < 0) {
     dim = dim + rank;
   }
+  // Initialize out_bound_ptr.
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(out_bound_ptr, 0, sizeof(bool), cuda_stream),
+                                     "In IndexFill kernel, cudaMemsetAsync out_bound variable failed.");
   // Prepare index_num, dim_size, outer_size, inner_size
   int64_t dim_size = 1;
   int64_t outer_size = 1;
   int64_t inner_size = 1;
   for (size_t i = 0; i < x_shape_.size(); i++) {
-    auto idx = static_cast<DimType>(i);
+    int idx = static_cast<int>(i);
     if (idx < dim) {
       outer_size *= x_shape_.at(i);
     } else if (idx > dim) {
@@ -124,14 +117,15 @@ bool IndexFillGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
       dim_size = x_shape_.at(i);
     }
   }
-  int64_t real_index_num = index_num_ * (outer_size * inner_size);
-  IndexFill(y_ptr, index_ptr, real_index_num, outer_size, dim_size, inner_size, value_ptr, out_bound_ptr, cuda_stream);
+  IndexFill(y_ptr, index_ptr, index_num_, outer_size, dim_size, inner_size, value_ptr, out_bound_ptr, device_id_,
+            cuda_stream);

+  bool out_bound = false;
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(&out_bound, out_bound_ptr, sizeof(bool), cudaMemcpyDeviceToHost, cuda_stream),
-    "cudaMemcpyAsync out_bound_ variable failed.");
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), "IndexFill cudaStreamSynchronized failed");
+    "In IndexFill kernel, cudaMemcpyAsync out_bound_ variable failed.");
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream),
+                                     "In IndexFill kernel, cudaStreamSynchronized failed");
   if (out_bound) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', the input 'index' is out of bound.";
     return false;
@@ -217,83 +211,6 @@ std::vector<std::pair<KernelAttr, IndexFillGpuKernelMod::IndexFillLaunchFunc>> I
      .AddInputAttr(kNumberTypeFloat64)
      .AddOutputAttr(kNumberTypeFloat64),
    &IndexFillGpuKernelMod::LaunchKernel<double, int>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt8)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt8)
-     .AddOutputAttr(kNumberTypeUInt8),
-   &IndexFillGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt16)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt16)
-     .AddOutputAttr(kNumberTypeUInt16),
-   &IndexFillGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt32)
-     .AddOutputAttr(kNumberTypeUInt32),
-   &IndexFillGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeUInt64)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeUInt64)
-     .AddOutputAttr(kNumberTypeUInt64),
-   &IndexFillGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt8)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt8)
-     .AddOutputAttr(kNumberTypeInt8),
-   &IndexFillGpuKernelMod::LaunchKernel<int8_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt16)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt16)
-     .AddOutputAttr(kNumberTypeInt16),
-   &IndexFillGpuKernelMod::LaunchKernel<int16_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddOutputAttr(kNumberTypeInt32),
-   &IndexFillGpuKernelMod::LaunchKernel<int32_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddOutputAttr(kNumberTypeInt64),
-   &IndexFillGpuKernelMod::LaunchKernel<int64_t, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat16)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeFloat16)
-     .AddOutputAttr(kNumberTypeFloat16),
-   &IndexFillGpuKernelMod::LaunchKernel<half, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeFloat32)
-     .AddOutputAttr(kNumberTypeFloat32),
-   &IndexFillGpuKernelMod::LaunchKernel<float, int64_t>},
-  {KernelAttr()
-     .AddInputAttr(kNumberTypeFloat64)
-     .AddInputAttr(kNumberTypeInt64)
-     .AddInputAttr(kNumberTypeInt32)
-     .AddInputAttr(kNumberTypeFloat64)
-     .AddOutputAttr(kNumberTypeFloat64),
-   &IndexFillGpuKernelMod::LaunchKernel<double, int64_t>},
 };

 std::vector<KernelAttr> IndexFillGpuKernelMod::GetOpSupport() {
@@ -304,6 +221,72 @@ std::vector<KernelAttr> IndexFillGpuKernelMod::GetOpSupport() {
   return support_list;
 }

-MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IndexFill, IndexFillGpuKernelMod);
+class IndexFillVmapGpuKernelMod : public IndexFillGpuKernelMod {
+ public:
+  IndexFillVmapGpuKernelMod() = default;
+  ~IndexFillVmapGpuKernelMod() override = default;
+
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs,
+             const std::map<uint32_t, tensor::TensorPtr> &others) override {
+    batch_rank_ = base_operator->get_batch_rank();
+    if (batch_rank_ <= 0) {
+      return IndexFillGpuKernelMod::Resize(base_operator, inputs, outputs, others);
+    } else {
+      auto input_shape = inputs.at(kIndex0)->GetShapeVector();
+      batch_size_ = std::accumulate(input_shape.begin(), input_shape.begin() + batch_rank_,
+                                    decltype(input_shape)::value_type(1), std::multiplies{});
+      int ret = IndexFillGpuKernelMod::Resize(base_operator, inputs, outputs, others);
+      auto new_inputs = inputs;
+      for (auto &input : new_inputs) {
+        auto shape = input->GetShapeVector();
+        std::vector<int64_t> new_shape(shape.begin() + batch_rank_, shape.end());
+        input->SetShapeVector(new_shape);
+      }
+      auto new_outputs = outputs;
+      for (auto &output : new_outputs) {
+        auto shape = output->GetShapeVector();
+        std::vector<int64_t> new_shape(shape.begin() + batch_rank_, shape.end());
+        output->SetShapeVector(new_shape);
+      }
+      UpdateSize(new_inputs, new_outputs);
+      return ret;
+    }
+  }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (batch_rank_ <= 0) {
+      return IndexFillGpuKernelMod::Launch(inputs, workspace, outputs, stream_ptr);
+    } else {
+      // Initialize address list of inputs and outputs.
+      std::vector<AddressPtr> new_inputs;
+      std::vector<AddressPtr> new_outputs;
+      (void)std::transform(
+        inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
+        [batch_size = batch_size_](auto ptr) { return std::make_shared<Address>(ptr->addr, ptr->size / batch_size); });
+      (void)std::transform(
+        outputs.begin(), outputs.end(), std::back_inserter(new_outputs),
+        [batch_size = batch_size_](auto ptr) { return std::make_shared<Address>(ptr->addr, ptr->size / batch_size); });
+      for (int64_t i = 0; i < batch_size_; i++) {
+        if (!IndexFillGpuKernelMod::Launch(new_inputs, workspace, new_outputs, stream_ptr)) {
+          return false;
+        }
+        (void)std::for_each(new_inputs.begin(), new_inputs.end(), [](auto &ptr) {
+          ptr->addr = reinterpret_cast<void *>(reinterpret_cast<char *>(ptr->addr) + (ptr->size));
+        });
+        (void)std::for_each(new_outputs.begin(), new_outputs.end(), [](auto &ptr) {
+          ptr->addr = reinterpret_cast<void *>(reinterpret_cast<char *>(ptr->addr) + (ptr->size));
+        });
+      }
+      return true;
+    }
+  }
+
+ private:
+  int64_t batch_rank_{0};
+  int64_t batch_size_{1};
+};
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IndexFill, IndexFillVmapGpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
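The Launch override in IndexFillVmapGpuKernelMod above flattens the vmapped leading dimensions into batch_size_, then runs the base kernel once per batch slice, advancing each address by its per-batch size after every launch (the workspace buffer is reused across iterations). A rough host-side sketch of that loop, with NumPy slices standing in for device addresses and a hypothetical launch_one callback standing in for the base kernel call:

# Rough sketch of the per-batch launch loop; illustrative only.
import numpy as np

def launch_batched(flat_inputs, flat_outputs, batch_size, launch_one):
    # Per-batch element counts, assuming every buffer divides evenly by batch_size.
    in_sizes = [buf.size // batch_size for buf in flat_inputs]
    out_sizes = [buf.size // batch_size for buf in flat_outputs]
    for b in range(batch_size):
        ins = [buf[b * n:(b + 1) * n] for buf, n in zip(flat_inputs, in_sizes)]
        outs = [buf[b * n:(b + 1) * n] for buf, n in zip(flat_outputs, out_sizes)]
        if not launch_one(ins, outs):  # mirror the early return on failure
            return False
    return True

def copy_kernel(ins, outs):
    outs[0][:] = ins[0]  # stand-in for the real per-batch IndexFill launch
    return True

x = np.arange(8.0)   # two batches of four elements
y = np.zeros(8)
assert launch_batched([x], [y], 2, copy_kernel) and np.array_equal(x, y)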
@@ -21,6 +21,7 @@
 #include <functional>
 #include <map>
 #include <utility>
+#include <memory>
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/index_fill_impl.cuh"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
@@ -46,11 +47,10 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
              const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  protected:
+  void UpdateSize(const std::vector<KernelTensorPtr> &, const std::vector<KernelTensorPtr> &);
   std::vector<KernelAttr> GetOpSupport() override;

  private:
-  void ResetResource() noexcept;
-
   template <typename DataType, typename IndexType>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs, void *stream_ptr);
@@ -60,8 +60,8 @@ class IndexFillGpuKernelMod : public NativeGpuKernelMod {
                                              const std::vector<AddressPtr> &, void *)>;
   static std::vector<std::pair<KernelAttr, IndexFillLaunchFunc>> func_list_;
   IndexFillLaunchFunc kernel_func_;
-  int64_t x_num_;
-  int64_t index_num_;
+  int64_t x_num_{0};
+  int64_t index_num_{0};
   std::vector<int64_t> x_shape_{};
 };
 }  // namespace kernel
@@ -22,7 +22,7 @@ __global__ void IndexFillKernel(DataType *out_ptr, const int *index_ptr, int64_t
                                 int64_t stride2) {
   for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < index_size; tid += blockDim.x * gridDim.x) {
     // Each index must be [-dim_size, dim_size)
-    int index = index_ptr[tid / stride1];
+    int64_t index = static_cast<int64_t>(index_ptr[tid / stride1]);
     if (index < -dim_size || index >= dim_size) {
       *out_bound_ptr = true;
       break;
@@ -39,50 +39,56 @@ __global__ void IndexFillKernel(DataType *out_ptr, const int *index_ptr, int64_t

 template <typename DataType>
 void IndexFill(DataType *out_ptr, const int *index_ptr, int64_t index_size, int64_t outer_size, int64_t dim_size,
-               int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream) {
+               int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+               cudaStream_t cuda_stream) {
   int64_t stride1 = outer_size * inner_size;
   int64_t stride2 = dim_size * inner_size;
-  IndexFillKernel<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(
-    out_ptr, index_ptr, index_size, dim_size, inner_size, value_ptr, out_bound_ptr, stride1, stride2);
+  int64_t real_index_size = index_size * (outer_size * inner_size);
+  IndexFillKernel<<<CUDA_BLOCKS(device_id, real_index_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    out_ptr, index_ptr, real_index_size, dim_size, inner_size, value_ptr, out_bound_ptr, stride1, stride2);
 }

 template CUDA_LIB_EXPORT void IndexFill<uint8_t>(uint8_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const uint8_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<uint16_t>(uint16_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                   int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                   const uint16_t *value_ptr, bool *out_bound_ptr,
-                                                  cudaStream_t cuda_stream);
+                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<uint32_t>(uint32_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                   int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                   const uint32_t *value_ptr, bool *out_bound_ptr,
-                                                  cudaStream_t cuda_stream);
+                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<uint64_t>(uint64_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                   int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                   const uint64_t *value_ptr, bool *out_bound_ptr,
-                                                  cudaStream_t cuda_stream);
+                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int8_t>(int8_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                                const int8_t *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                                const int8_t *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                                cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int16_t>(int16_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const int16_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int32_t>(int32_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const int32_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<int64_t>(int64_t *out_ptr, const int *index_ptr, int64_t index_size,
                                                  int64_t outer_size, int64_t dim_size, int64_t inner_size,
                                                  const int64_t *value_ptr, bool *out_bound_ptr,
-                                                 cudaStream_t cuda_stream);
+                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<half>(half *out_ptr, const int *index_ptr, int64_t index_size,
                                               int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                              const half *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                              const half *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                              cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<float>(float *out_ptr, const int *index_ptr, int64_t index_size,
                                                int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                               const float *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                               const float *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                               cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void IndexFill<double>(double *out_ptr, const int *index_ptr, int64_t index_size,
                                                 int64_t outer_size, int64_t dim_size, int64_t inner_size,
-                                                const double *value_ptr, bool *out_bound_ptr, cudaStream_t cuda_stream);
+                                                const double *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
+                                                cudaStream_t cuda_stream);
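The kernel above treats the output as an (outer_size, dim_size, inner_size) view around the fill axis: stride2 = dim_size * inner_size separates consecutive outer blocks, and each index element covers outer_size * inner_size output positions, which is why real_index_size multiplies them in. A hedged NumPy reference of that flat indexing (the helper name is illustrative, not from the commit):

# Hedged NumPy reference for the flat (outer, dim, inner) indexing used above.
import numpy as np

def index_fill_flat_ref(x_flat, x_shape, dim, index, value):
    outer_size = int(np.prod(x_shape[:dim], dtype=np.int64))      # dims before `dim`
    dim_size = int(x_shape[dim])
    inner_size = int(np.prod(x_shape[dim + 1:], dtype=np.int64))  # dims after `dim`
    stride2 = dim_size * inner_size  # distance between consecutive outer blocks, as in IndexFill()
    out = x_flat.copy()
    for idx in index:
        i = idx + dim_size if idx < 0 else idx  # negative indices wrap, matching the bounds check
        for outer in range(outer_size):
            start = outer * stride2 + i * inner_size
            out[start:start + inner_size] = value
    return out

x = np.arange(24, dtype=np.float32)
expected = x.reshape(2, 3, 4).copy()
expected[:, [0, -1], :] = -1.0
filled = index_fill_flat_ref(x, (2, 3, 4), dim=1, index=[0, -1], value=-1.0)
assert np.array_equal(filled.reshape(2, 3, 4), expected)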
@@ -22,6 +22,6 @@
 template <typename DataType>
 CUDA_LIB_EXPORT void IndexFill(DataType *out_ptr, const int *index_ptr, int64_t index_size, int64_t outer_size,
                                int64_t dim_size, int64_t inner_size, const DataType *value_ptr, bool *out_bound_ptr,
-                               cudaStream_t cuda_stream);
+                               const uint32_t &device_id, cudaStream_t cuda_stream);

 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INDEX_Fill_IMPL_CUH_
@@ -58,24 +58,30 @@ TypePtr IndexFillInferType(const PrimitivePtr &primitive, const std::vector<Abst
 abstract::ShapePtr IndexFillInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
+  size_t batch_rank = 0;
+  if (primitive->HasAttr(kBatchRank)) {
+    auto value_ptr = primitive->GetAttr(kBatchRank);
+    batch_rank = GetValue<int64_t>(value_ptr);
+  }
+
   // Input 'dim' must be a tensor with a value or a scalar.
   if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
     auto dim_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
     auto dim_rank = SizeToLong(dim_shape.size());
-    (void)CheckAndConvertUtils::CheckInteger("rank of 'dim'", dim_rank, kEqual, 0, prim_name);
+    (void)CheckAndConvertUtils::CheckInteger("rank of 'dim'", dim_rank, kEqual, batch_rank, prim_name);
   }

   // Input 'index' must be a scalar/vector.
   auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
   auto index_rank = SizeToLong(index_shape.size());
-  (void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {0, 1}, prim_name);
+  (void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {batch_rank, batch_rank + 1},
+                                           prim_name);

   // Input 'value' must be a tensor with a value or a scalar.
   if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>()) {
     auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
     auto value_rank = SizeToLong(value_shape.size());
-    (void)CheckAndConvertUtils::CheckInteger("rank of 'value'", value_rank, kEqual, 0, prim_name);
+    (void)CheckAndConvertUtils::CheckInteger("rank of 'value'", value_rank, kEqual, batch_rank, prim_name);
   }

   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
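With kBatchRank set by the vmap rule, the relaxed checks above expect one leading batch dimension on 'dim', 'index', and 'value'. A small illustrative sketch of shapes that now pass when batch_rank is 1 (example values only, not taken from the commit):

# Illustrative shapes satisfying the relaxed rank checks for batch_rank == 1.
import numpy as np

batch_rank = 1
batch = 2
x = np.zeros((batch, 3, 4, 5), dtype=np.float32)  # data with a leading vmapped axis
dim = np.zeros((batch,), dtype=np.int32)          # rank == batch_rank (was 0 before)
index = np.zeros((batch, 4), dtype=np.int32)      # rank in [batch_rank, batch_rank + 1]
value = np.zeros((batch,), dtype=np.float32)      # rank == batch_rank (was 0 before)

assert dim.ndim == batch_rank
assert batch_rank <= index.ndim <= batch_rank + 1
assert value.ndim == batch_rank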
@@ -31,6 +31,7 @@ from ..operations.array_ops import Fills
 from ..operations.array_ops import UniqueConsecutive
 from ..operations.array_ops import Col2Im
 from ..operations.array_ops import NonZero
+from ..operations.array_ops import IndexFill


 @vmap_rules_getters.register("Cast")
@@ -1098,6 +1099,38 @@ def get_gather_vmap_rule(prim, axis_size):
     return vmap_rule


+@vmap_rules_getters.register(IndexFill)
+def get_index_fill_rule(prim, axis_size):
+    """VmapRule for `IndexFill` operation."""
+    if hasattr(prim, 'batch_rank'):
+        batch_rank = prim.batch_rank + 1
+    else:
+        batch_rank = 1
+
+    batch_prim = IndexFill()
+    batch_prim.add_prim_attr('batch_rank', batch_rank)
+
+    def vmap_rule(x_bdim, dim_bdim, index_bdim, value_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim, value_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        dim, dim_dim = dim_bdim
+        index, index_dim = index_bdim
+        value, value_dim = value_bdim
+
+        x = _bdim_at_front(x, x_dim, axis_size)
+        dim = _bdim_at_front(dim, dim_dim, axis_size)
+        index = _bdim_at_front(index, index_dim, axis_size)
+        value = _bdim_at_front(value, value_dim, axis_size)
+
+        out = batch_prim(x, dim, index, value)
+        return out, 0
+
+    return vmap_rule
+
+
 @vmap_rules_getters.register(P.DataFormatDimMap)
 def get_pdist_vmap_rule(prim, axis_size):
     """VmapRule for `DataFormatDimMap`"""
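The rule above relies on _bdim_at_front to bring each argument's batch dimension to axis 0, broadcasting arguments that carry no batch dimension; a rough NumPy approximation of that behavior (not MindSpore's actual implementation):

# Rough NumPy stand-in for _bdim_at_front as used by the rule above.
import numpy as np

def bdim_at_front_ref(x, x_dim, axis_size):
    x = np.asarray(x)
    if x_dim is None:
        # No batch axis: replicate along a new leading axis of length axis_size.
        return np.broadcast_to(x, (axis_size,) + x.shape)
    return np.moveaxis(x, x_dim, 0)

a = np.ones((3, 2))                            # batch axis at position 1
print(bdim_at_front_ref(a, 1, 2).shape)        # (2, 3)
print(bdim_at_front_ref(5.0, None, 2).shape)   # (2,)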
@@ -20,6 +20,7 @@ import mindspore.nn as nn
 from mindspore import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.ops as ops
+from mindspore.ops.functional import vmap


 class IndexFillNet(nn.Cell):
@@ -32,18 +33,7 @@
         return out


-def compare_with_numpy(x, dim, index, value):
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    ms_x = Tensor(x)
-    ms_dim = dim
-    ms_index = Tensor(index)
-    ms_value = value
-    ms_result_graph = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
-    # PyNative Mode
-    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    ms_result_pynative = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
-
-    # Numpy
+def numpy_index_fill(x, dim, index, value):
     np_result = x.copy()
     if dim == 0:
         np_result[index] = value
@@ -53,6 +43,23 @@
         np_result[:, :, index] = value
     else:
         raise ValueError("dim must be 0, 1 or 2")
+    return np_result
+
+
+def compare_with_numpy(x, dim, index, value):
+    # Graph Mode
+    ms_x = Tensor(x)
+    ms_dim = dim
+    ms_index = Tensor(index)
+    ms_value = value
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    ms_result_graph = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
+    # PyNative Mode
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    ms_result_pynative = IndexFillNet()(ms_x, ms_dim, ms_index, ms_value).asnumpy()
+
+    # Numpy
+    np_result = numpy_index_fill(x, dim, index, value)
+
     return np.allclose(ms_result_graph, np_result) and np.allclose(ms_result_pynative, np_result)


@@ -60,7 +67,7 @@
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('data_type', [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64])
+@pytest.mark.parametrize('data_type', [np.int8, np.int16, np.int32, np.float16, np.float64])
 def test_index_fill_data_type(data_type):
     """
     Feature: IndexFill
@@ -127,3 +134,43 @@ def test_index_fill_error(dim, data_type):

     with pytest.raises(RuntimeError):
         IndexFillNet()(ms_x, ms_dim, ms_index, ms_value)
+
+
+class IndexFillVmapNet(nn.Cell):
+    def __init__(self, net, in_axes):
+        super(IndexFillVmapNet, self).__init__()
+        self.net = net
+        self.vmap_index_fill = vmap(self.net, in_axes=in_axes, out_axes=0)
+
+    def construct(self, x, dim, index, value):
+        out = self.vmap_index_fill(x, dim, index, value)
+        return out
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_index_fill_vmap():
+    """
+    Feature: IndexFill
+    Description: test cases of vmap for IndexFill operator.
+    Expectation: the result match numpy.
+    """
+    data_type = np.float32
+    dim_type = np.int32
+    dim = Tensor(np.array([0, 1], dtype=dim_type))
+    value = Tensor(np.array([-20, -10], dtype=data_type))
+    x_np = np.random.random(size=(2, 5, 5)).astype(data_type)
+    index_np = np.random.randint(low=0, high=5, size=(2, 4)).astype(np.int32)
+
+    # MindSpore
+    ms_x = Tensor(x_np)
+    ms_index = Tensor(index_np)
+    ms_result_graph = IndexFillVmapNet(IndexFillNet(), (0, 0, 0, 0))(ms_x, dim, ms_index, value).asnumpy()
+
+    # NumPy
+    np_result = [None, None]
+    for i in range(2):
+        np_result[i] = numpy_index_fill(x_np[i], dim[i], index_np[i], value[i])
+    np_result = np.asarray(np_result)
+    assert np.allclose(ms_result_graph, np_result)