!36404 Opt MatrixDiagV3 operator in GPU platform.
Merge pull request !36404 from hezhenhao1/add_matrix_diag
Commit: 09e010a0d5
@@ -29,7 +29,6 @@ namespace kernel {
 constexpr int kMatrixDiagV3InputsNum = 5;
 constexpr int kMatrixDiagV3OutputsNum = 1;
-constexpr int kMatrixDiagV3MinOutputShape = 2;
 }  // namespace
 
 bool MatrixDiagV3GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -54,64 +53,60 @@ bool MatrixDiagV3GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
 int MatrixDiagV3GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                      const std::vector<KernelTensorPtr> &outputs,
                                      const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
-  int ret;
-  if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != KRET_OK) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
     return ret;
   }
   auto x_shape = inputs.at(kIndex0)->GetShapeVector();
   x_size_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies{});
-  if (x_size_ == 0) {
-    return ret;
+  if (x_shape.size() < kDim1) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', resize failed, some undefined behaviors happened.";
+    return KRET_RESIZE_FAILED;
   }
 
   max_diag_len_ = x_shape.back();
   auto k_shape = inputs.at(kIndex1)->GetShapeVector();
-  k_size_ = std::accumulate(k_shape.begin(), k_shape.end(), 1, std::multiplies{});
+  k_size_ = std::accumulate(k_shape.begin(), k_shape.end(), int64_t(1), std::multiplies{});
   y_shape_ = outputs.at(kIndex0)->GetShapeVector();
-  y_size_ = std::accumulate(y_shape_.begin(), y_shape_.end(), 1, std::multiplies{});
-  if (y_shape_.size() < kMatrixDiagV3MinOutputShape) {
+  y_size_ = std::accumulate(y_shape_.begin(), y_shape_.end(), int64_t(1), std::multiplies{});
+  if (y_shape_.size() < kDim2) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', resize failed, some undefined behaviors happened.";
     return KRET_RESIZE_FAILED;
   }
   num_cols_ = y_shape_.at(y_shape_.size() - kIndex1);
   num_rows_ = y_shape_.at(y_shape_.size() - kIndex2);
-  return ret;
+  return KRET_OK;
 }
 
 template <typename DataType>
 bool MatrixDiagV3GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                             const std::vector<AddressPtr> &workspace,
                                             const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-  if (x_size_ == 0) {
+  if (x_size_ == 0 || y_size_ == 0) {
     return true;
   }
-  if (!IsValidShape(y_shape_)) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of 'y' is invalid, since all the inputs are not ready.";
-    return false;
-  }
   auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
   auto x_ptr = GetDeviceAddress<DataType>(inputs, kIndex0);
   auto k_ptr = GetDeviceAddress<kIntType>(inputs, kIndex1);
   auto padding_value_ptr = GetDeviceAddress<DataType>(inputs, kIndex4);
   auto y_ptr = GetDeviceAddress<DataType>(outputs, kIndex0);
-  bool is_nullptr = (x_ptr == nullptr) || (k_ptr == nullptr) || (padding_value_ptr == nullptr) || (y_ptr == nullptr);
-  if (is_nullptr) {
+  auto any = [](auto &&... args) -> bool { return ((args == nullptr) || ...); };
+  if (any(cuda_stream, x_ptr, k_ptr, padding_value_ptr, y_ptr)) {
     return false;
   }
   // Get 'k' and store as [lower_diag_index, upper_diag_index].
   kIntType k_stand;
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&k_stand, k_ptr, sizeof(kIntType), cudaMemcpyDeviceToHost),
-                                     "cudaMemcpy input 'k' to host failed.");
+                                     "In MatrixDiagV3 kernel, cudaMemcpy input 'k' to host failed.");
   int64_t upper_diag_index, lower_diag_index = IntToLong(k_stand);
   if (k_size_ == 1) {
     upper_diag_index = lower_diag_index;
   } else {
     CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpy(&k_stand, k_ptr + 1, sizeof(kIntType), cudaMemcpyDeviceToHost),
-                                       "cudaMemcpy input 'k' to host failed.");
+                                       "In MatrixDiagV3 kernel, cudaMemcpy input 'k' to host failed.");
     upper_diag_index = IntToLong(k_stand);
   }
 
   MatrixDiagV3(x_ptr, padding_value_ptr, y_ptr, y_size_, num_rows_, num_cols_, lower_diag_index, upper_diag_index,
-               max_diag_len_, left_align_super_diag_, left_align_sub_diag_, cuda_stream);
+               max_diag_len_, left_align_super_diag_, left_align_sub_diag_, device_id_, cuda_stream);
   return true;
 }

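Two C++17 idioms carry the Resize/LaunchKernel cleanup above: an if-statement with an initializer, which scopes the status code to the one check that consumes it, and a generic lambda whose fold expression collapses the chained null-pointer comparisons. A minimal standalone sketch of both, using illustrative names that are not part of the commit:

#include <cstdio>

int DoWork() { return 0; }  // stand-in for KernelMod::Resize; 0 plays the role of KRET_OK

bool Launch(const float *x, const float *k, const float *y) {
  // Unary right fold over ||: evaluates to true as soon as any argument is null.
  auto any = [](auto &&... args) -> bool { return ((args == nullptr) || ...); };
  if (any(x, k, y)) {
    return false;
  }
  return true;
}

int main() {
  // The initializer runs first; 'ret' is visible only inside this if/else.
  if (int ret = DoWork(); ret != 0) {
    return ret;
  }
  float a = 1.0f;
  printf("%d\n", Launch(&a, &a, nullptr));  // prints 0: the null argument is rejected
  return 0;
}
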
@@ -17,112 +17,118 @@
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_diag_v3_impl.cuh"
 #include <algorithm>
 
-__device__ inline int64_t ComputeOffset(const int64_t diag_idx, const int64_t num_rows, const int64_t num_cols,
-                                        const int64_t max_diag_len, const bool left_align_super_diag,
-                                        const bool left_align_sub_diag) {
+__device__ inline int ComputeOffset(int diag_idx, int num_rows, int num_cols, int max_diag_len,
+                                    bool left_align_super_diag, bool left_align_sub_diag) {
   bool left_align = (diag_idx >= 0 && left_align_super_diag) || (diag_idx <= 0 && left_align_sub_diag);
   if (left_align) {
     return 0;
   }
-  int64_t diag_len1 = num_cols - max(diag_idx, int64_t(0));
-  int64_t diag_len2 = num_rows + min(diag_idx, int64_t(0));
+  int diag_len1 = num_cols - max(diag_idx, 0);
+  int diag_len2 = num_rows + min(diag_idx, 0);
   return max_diag_len - min(diag_len1, diag_len2);
 }
 
 template <typename DataType>
 __global__ void MatrixDiagV3Kernel(const DataType *x_ptr, const DataType *padding_value_ptr, DataType *y_ptr,
-                                   const int64_t y_size, const int64_t num_rows, const int64_t num_cols,
-                                   const int64_t lower_diag_index, const int64_t upper_diag_index,
-                                   const int64_t diag_batch_len, const int64_t max_diag_len,
-                                   const bool left_align_super_diag, const bool left_align_sub_diag) {
-  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < y_size; idx += blockDim.x * gridDim.x) {
-    int64_t batch_row_idx = idx / num_cols;
-    int64_t col_idx = idx - batch_row_idx * num_cols;
-    int64_t batch_idx = batch_row_idx / num_rows;
-    int64_t row_idx = batch_row_idx - batch_idx * num_rows;
-    int64_t diag_idx = col_idx - row_idx;
+                                   int y_size, int num_rows, int num_cols, int lower_diag_index, int upper_diag_index,
+                                   int diag_batch_len, int max_diag_len, bool left_align_super_diag,
+                                   bool left_align_sub_diag) {
+  DataType padding_value = *padding_value_ptr;
+  int start_idx = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
+  int step = static_cast<int>(blockDim.x * gridDim.x);
+  for (int idx = start_idx; idx < y_size; idx += step) {
+    int batch_row_idx = idx / num_cols;
+    int col_idx = idx - batch_row_idx * num_cols;
+    int batch_idx = batch_row_idx / num_rows;
+    int row_idx = batch_row_idx - batch_idx * num_rows;
+    int diag_idx = col_idx - row_idx;
     if (lower_diag_index <= diag_idx && diag_idx <= upper_diag_index) {
-      int64_t offset =
+      int offset =
        ComputeOffset(diag_idx, num_rows, num_cols, max_diag_len, left_align_super_diag, left_align_sub_diag);
-      int64_t diag_row_idx = upper_diag_index - diag_idx;
-      int64_t diag_col_idx = col_idx - max(diag_idx, int64_t(0)) + offset;
+      int diag_row_idx = upper_diag_index - diag_idx;
+      int diag_col_idx = col_idx - max(diag_idx, 0) + offset;
       y_ptr[idx] = x_ptr[batch_idx * diag_batch_len + diag_row_idx * max_diag_len + diag_col_idx];
     } else {
-      y_ptr[idx] = *padding_value_ptr;
+      y_ptr[idx] = padding_value;
     }
   }
 }
 
 template <typename DataType>
-void MatrixDiagV3(const DataType *x_ptr, const DataType *padding_value_ptr, DataType *y_ptr, const int64_t y_size,
-                  const int64_t num_rows, const int64_t num_cols, const int64_t lower_diag_index,
-                  const int64_t upper_diag_index, const int64_t max_diag_len, const bool left_align_super_diag,
-                  const bool left_align_sub_diag, cudaStream_t cuda_stream) {
+void MatrixDiagV3(const DataType *x_ptr, const DataType *padding_value_ptr, DataType *y_ptr, int64_t y_size,
+                  int64_t num_rows, int64_t num_cols, int64_t lower_diag_index, int64_t upper_diag_index,
+                  int64_t max_diag_len, bool left_align_super_diag, bool left_align_sub_diag, uint32_t device_id,
+                  cudaStream_t cuda_stream) {
   int64_t diag_batch_len = (upper_diag_index - lower_diag_index + 1) * max_diag_len;
-  MatrixDiagV3Kernel<<<GET_BLOCKS(y_size), GET_THREADS, 0, cuda_stream>>>(
-    x_ptr, padding_value_ptr, y_ptr, y_size, num_rows, num_cols, lower_diag_index, upper_diag_index, diag_batch_len,
-    max_diag_len, left_align_super_diag, left_align_sub_diag);
+  MatrixDiagV3Kernel<<<CUDA_BLOCKS(device_id, y_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    x_ptr, padding_value_ptr, y_ptr, static_cast<int>(y_size), static_cast<int>(num_rows), static_cast<int>(num_cols),
+    static_cast<int>(lower_diag_index), static_cast<int>(upper_diag_index), static_cast<int>(diag_batch_len),
+    static_cast<int>(max_diag_len), left_align_super_diag, left_align_sub_diag);
 }
 
 template CUDA_LIB_EXPORT void MatrixDiagV3<int8_t>(const int8_t *x_ptr, const int8_t *padding_value_ptr, int8_t *y_ptr,
-                                                   const int64_t y_size, const int64_t num_rows, const int64_t num_cols,
-                                                   const int64_t lower_diag_index, const int64_t upper_diag_index,
-                                                   const int64_t max_diag_len, const bool left_align_super_diag,
-                                                   const bool left_align_sub_diag, cudaStream_t cuda_stream);
+                                                   int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                   int64_t lower_diag_index, int64_t upper_diag_index,
+                                                   int64_t max_diag_len, bool left_align_super_diag,
+                                                   bool left_align_sub_diag, uint32_t device_id,
+                                                   cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<int16_t>(const int16_t *x_ptr, const int16_t *padding_value_ptr,
-                                                    int16_t *y_ptr, const int64_t y_size, const int64_t num_rows,
-                                                    const int64_t num_cols, const int64_t lower_diag_index,
-                                                    const int64_t upper_diag_index, const int64_t max_diag_len,
-                                                    const bool left_align_super_diag, const bool left_align_sub_diag,
+                                                    int16_t *y_ptr, int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                    int64_t lower_diag_index, int64_t upper_diag_index,
+                                                    int64_t max_diag_len, bool left_align_super_diag,
+                                                    bool left_align_sub_diag, uint32_t device_id,
                                                     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<int32_t>(const int32_t *x_ptr, const int32_t *padding_value_ptr,
-                                                    int32_t *y_ptr, const int64_t y_size, const int64_t num_rows,
-                                                    const int64_t num_cols, const int64_t lower_diag_index,
-                                                    const int64_t upper_diag_index, const int64_t max_diag_len,
-                                                    const bool left_align_super_diag, const bool left_align_sub_diag,
+                                                    int32_t *y_ptr, int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                    int64_t lower_diag_index, int64_t upper_diag_index,
+                                                    int64_t max_diag_len, bool left_align_super_diag,
+                                                    bool left_align_sub_diag, uint32_t device_id,
                                                     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<int64_t>(const int64_t *x_ptr, const int64_t *padding_value_ptr,
-                                                    int64_t *y_ptr, const int64_t y_size, const int64_t num_rows,
-                                                    const int64_t num_cols, const int64_t lower_diag_index,
-                                                    const int64_t upper_diag_index, const int64_t max_diag_len,
-                                                    const bool left_align_super_diag, const bool left_align_sub_diag,
+                                                    int64_t *y_ptr, int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                    int64_t lower_diag_index, int64_t upper_diag_index,
+                                                    int64_t max_diag_len, bool left_align_super_diag,
+                                                    bool left_align_sub_diag, uint32_t device_id,
                                                     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<uint8_t>(const uint8_t *x_ptr, const uint8_t *padding_value_ptr,
-                                                    uint8_t *y_ptr, const int64_t y_size, const int64_t num_rows,
-                                                    const int64_t num_cols, const int64_t lower_diag_index,
-                                                    const int64_t upper_diag_index, const int64_t max_diag_len,
-                                                    const bool left_align_super_diag, const bool left_align_sub_diag,
+                                                    uint8_t *y_ptr, int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                    int64_t lower_diag_index, int64_t upper_diag_index,
+                                                    int64_t max_diag_len, bool left_align_super_diag,
+                                                    bool left_align_sub_diag, uint32_t device_id,
                                                     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<uint16_t>(const uint16_t *x_ptr, const uint16_t *padding_value_ptr,
-                                                     uint16_t *y_ptr, const int64_t y_size, const int64_t num_rows,
-                                                     const int64_t num_cols, const int64_t lower_diag_index,
-                                                     const int64_t upper_diag_index, const int64_t max_diag_len,
-                                                     const bool left_align_super_diag, const bool left_align_sub_diag,
-                                                     cudaStream_t cuda_stream);
+                                                     uint16_t *y_ptr, int64_t y_size, int64_t num_rows,
+                                                     int64_t num_cols, int64_t lower_diag_index,
+                                                     int64_t upper_diag_index, int64_t max_diag_len,
+                                                     bool left_align_super_diag, bool left_align_sub_diag,
+                                                     uint32_t device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<uint32_t>(const uint32_t *x_ptr, const uint32_t *padding_value_ptr,
-                                                     uint32_t *y_ptr, const int64_t y_size, const int64_t num_rows,
-                                                     const int64_t num_cols, const int64_t lower_diag_index,
-                                                     const int64_t upper_diag_index, const int64_t max_diag_len,
-                                                     const bool left_align_super_diag, const bool left_align_sub_diag,
-                                                     cudaStream_t cuda_stream);
+                                                     uint32_t *y_ptr, int64_t y_size, int64_t num_rows,
+                                                     int64_t num_cols, int64_t lower_diag_index,
+                                                     int64_t upper_diag_index, int64_t max_diag_len,
+                                                     bool left_align_super_diag, bool left_align_sub_diag,
+                                                     uint32_t device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<uint64_t>(const uint64_t *x_ptr, const uint64_t *padding_value_ptr,
-                                                     uint64_t *y_ptr, const int64_t y_size, const int64_t num_rows,
-                                                     const int64_t num_cols, const int64_t lower_diag_index,
-                                                     const int64_t upper_diag_index, const int64_t max_diag_len,
-                                                     const bool left_align_super_diag, const bool left_align_sub_diag,
-                                                     cudaStream_t cuda_stream);
+                                                     uint64_t *y_ptr, int64_t y_size, int64_t num_rows,
+                                                     int64_t num_cols, int64_t lower_diag_index,
+                                                     int64_t upper_diag_index, int64_t max_diag_len,
+                                                     bool left_align_super_diag, bool left_align_sub_diag,
+                                                     uint32_t device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<half>(const half *x_ptr, const half *padding_value_ptr, half *y_ptr,
-                                                 const int64_t y_size, const int64_t num_rows, const int64_t num_cols,
-                                                 const int64_t lower_diag_index, const int64_t upper_diag_index,
-                                                 const int64_t max_diag_len, const bool left_align_super_diag,
-                                                 const bool left_align_sub_diag, cudaStream_t cuda_stream);
+                                                 int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                 int64_t lower_diag_index, int64_t upper_diag_index,
+                                                 int64_t max_diag_len, bool left_align_super_diag,
+                                                 bool left_align_sub_diag, uint32_t device_id,
+                                                 cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<float>(const float *x_ptr, const float *padding_value_ptr, float *y_ptr,
-                                                  const int64_t y_size, const int64_t num_rows, const int64_t num_cols,
-                                                  const int64_t lower_diag_index, const int64_t upper_diag_index,
-                                                  const int64_t max_diag_len, const bool left_align_super_diag,
-                                                  const bool left_align_sub_diag, cudaStream_t cuda_stream);
+                                                  int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                  int64_t lower_diag_index, int64_t upper_diag_index,
+                                                  int64_t max_diag_len, bool left_align_super_diag,
+                                                  bool left_align_sub_diag, uint32_t device_id,
+                                                  cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void MatrixDiagV3<double>(const double *x_ptr, const double *padding_value_ptr, double *y_ptr,
-                                                   const int64_t y_size, const int64_t num_rows, const int64_t num_cols,
-                                                   const int64_t lower_diag_index, const int64_t upper_diag_index,
-                                                   const int64_t max_diag_len, const bool left_align_super_diag,
-                                                   const bool left_align_sub_diag, cudaStream_t cuda_stream);
+                                                   int64_t y_size, int64_t num_rows, int64_t num_cols,
+                                                   int64_t lower_diag_index, int64_t upper_diag_index,
+                                                   int64_t max_diag_len, bool left_align_super_diag,
+                                                   bool left_align_sub_diag, uint32_t device_id,
+                                                   cudaStream_t cuda_stream);

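ComputeOffset above encodes the alignment rule for the packed diagonal layout: each diagonal occupies a fixed row of max_diag_len elements, a left-aligned diagonal starts at column 0, and a right-aligned one is shifted right by the gap between max_diag_len and the diagonal's true length min(num_cols - max(k, 0), num_rows + min(k, 0)). A host-side sketch of the same arithmetic with hand-checked values (the Host suffix and the driver are illustrative only, not part of the commit):

#include <algorithm>
#include <cstdio>

int ComputeOffsetHost(int diag_idx, int num_rows, int num_cols, int max_diag_len,
                      bool left_align_super_diag, bool left_align_sub_diag) {
  bool left_align = (diag_idx >= 0 && left_align_super_diag) || (diag_idx <= 0 && left_align_sub_diag);
  if (left_align) {
    return 0;
  }
  int diag_len1 = num_cols - std::max(diag_idx, 0);
  int diag_len2 = num_rows + std::min(diag_idx, 0);
  return max_diag_len - std::min(diag_len1, diag_len2);
}

int main() {
  // 3x4 output, band k in [-1, 1], so max_diag_len = 3; all diagonals right-aligned.
  // k = 1 has length min(4 - 1, 3 + 0) = 3, so no shift is needed.
  printf("%d\n", ComputeOffsetHost(1, 3, 4, 3, false, false));   // 0
  // k = -1 has length min(4 - 0, 3 - 1) = 2, so its payload starts one column in.
  printf("%d\n", ComputeOffsetHost(-1, 3, 4, 3, false, false));  // 1
  return 0;
}
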
@@ -21,9 +21,8 @@
 
 template <typename DataType>
 CUDA_LIB_EXPORT void MatrixDiagV3(const DataType *x_ptr, const DataType *padding_value_ptr, DataType *y_ptr,
-                                  const int64_t y_size, const int64_t num_rows, const int64_t num_cols,
-                                  const int64_t lower_diag_index, const int64_t upper_diag_index,
-                                  const int64_t max_diag_len, const bool left_align_super_diag_,
-                                  const bool left_align_sub_diag_, cudaStream_t cuda_stream);
+                                  int64_t y_size, int64_t num_rows, int64_t num_cols, int64_t lower_diag_index,
+                                  int64_t upper_diag_index, int64_t max_diag_len, bool left_align_super_diag_,
+                                  bool left_align_sub_diag_, uint32_t device_id, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MATRIX_DIAG_V3_IMPL_CUH_
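For orientation, a hypothetical call site for the exported entry point declared above (the device pointers, extents, and stream are illustrative; the GpuKernelMod supplies the real ones after reading 'k'). It expands a batch of two 3x4 matrices from the diagonal band [-1, 1]:

// d_x holds packed diagonals of shape [2, 3, 3] (batch, upper - lower + 1, max_diag_len),
// d_padding is a device scalar, and d_y receives the [2, 3, 4] output.
MatrixDiagV3<float>(d_x, d_padding, d_y,
                    /*y_size=*/2 * 3 * 4, /*num_rows=*/3, /*num_cols=*/4,
                    /*lower_diag_index=*/-1, /*upper_diag_index=*/1, /*max_diag_len=*/3,
                    /*left_align_super_diag=*/true, /*left_align_sub_diag=*/true,
                    /*device_id=*/0, stream);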