forked from mindspore-Ecosystem/mindspore
fix gpu svd
parent 2a77abfb91
commit be0facf3cf
@@ -14,7 +14,7 @@ mindspore.Tensor.svd
 Returns:
     - **s** (Tensor) - The singular values. The shape is :math:`(*, P)`.
     - **u** (Tensor) - The left singular vectors. If compute_uv is False, u will not be returned. The shape is :math:`(*, M, P)`. If full_matrices is True, the shape is :math:`(*, M, M)`.
-    - **v** (Tensor) - The right singular vectors. If compute_uv is False, v will not be returned. The shape is :math:`(*, P, N)`. If full_matrices is True, the shape is :math:`(*, N, N)`.
+    - **v** (Tensor) - The right singular vectors. If compute_uv is False, v will not be returned. The shape is :math:`(*, N, P)`. If full_matrices is True, the shape is :math:`(*, N, N)`.

 Raises:
     - **TypeError** - If `full_matrices` or `compute_uv` is not a bool.
@@ -18,7 +18,7 @@ mindspore.ops.svd
 Returns:
     - **s** (Tensor) - The singular values. The shape is :math:`(*, P)`.
     - **u** (Tensor) - The left singular vectors. If `compute_uv` is False, u will not be returned. The shape is :math:`(*, M, P)`. If `full_matrices` is True, the shape is :math:`(*, M, M)`.
-    - **v** (Tensor) - The right singular vectors. If `compute_uv` is False, v will not be returned. The shape is :math:`(*, P, N)`. If `full_matrices` is True, the shape is :math:`(*, N, N)`.
+    - **v** (Tensor) - The right singular vectors. If `compute_uv` is False, v will not be returned. The shape is :math:`(*, N, P)`. If `full_matrices` is True, the shape is :math:`(*, N, N)`.

 Raises:
     - **TypeError** - If `full_matrices` or `compute_uv` is not a bool.
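The two documentation hunks above correct the documented shape of v from :math:`(*, P, N)` to :math:`(*, N, P)`, matching what the fixed GPU kernel actually writes. A minimal sanity check of the corrected shapes, as a sketch that assumes a MindSpore build with SVD support (the 3x5 input is illustrative):

import numpy as np
from mindspore import Tensor, ops

a = Tensor(np.random.rand(3, 5).astype(np.float32))  # M = 3, N = 5, P = min(M, N) = 3
s, u, v = ops.svd(a, full_matrices=False, compute_uv=True)
print(s.shape, u.shape, v.shape)  # (3,), (3, 3), (5, 3): v is (N, P), not (P, N)
# With v stored as (N, P), the input is recovered as u @ diag(s) @ v^T.
a_hat = ops.matmul(u, ops.matmul(ops.diag(s), v.transpose()))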
@@ -31,8 +31,7 @@ bool SvdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vect
     MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
     return false;
   }
-  launch_kernel_func_ = func_list_[index].second.first;
-  init_size_lists_func_ = func_list_[index].second.second;
+  launch_kernel_func_ = func_list_[index].second;
   compute_uv_ = kernel_ptr->compute_uv();
   full_matrices_ = kernel_ptr->full_matrices();
   job_ = compute_uv_ ? (full_matrices_ ? 'A' : 'S') : 'N';
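The job_ character chosen in Init follows the LAPACK/cuSOLVER gesvd convention: 'A' returns the full square u and v, 'S' the thin factors, 'N' singular values only. A NumPy stand-in for the three modes (illustration only, not the kernel's code path):

import numpy as np

a = np.random.rand(5, 3).astype(np.float32)
u_a, s_a, vt_a = np.linalg.svd(a, full_matrices=True)   # job 'A': u is (5, 5), vt is (3, 3)
u_s, s_s, vt_s = np.linalg.svd(a, full_matrices=False)  # job 'S': thin u (5, 3)
s_only = np.linalg.svd(a, compute_uv=False)             # job 'N': singular values only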
@@ -47,21 +46,11 @@ int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
     return ret;
   }
   auto input_shape = inputs[kIndex0]->GetShapeVector();
-  if (IsDynamicRank(input_shape)) {
-    return KRET_OK;
-  }
-  DestroyResource();
-  ResetResource();
   input_shape_ = Convert2SizeTClipNeg(input_shape);
   total_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies<size_t>());
-  is_null_input_ = (total_size_ == 0);
-  if (is_null_input_) {
-    init_size_lists_func_(this);
-    return KRET_OK;
-  }
   dims_ = input_shape_.size();
   if (dims_ < kDim2) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got [" << dims_;
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got " << dims_;
   }

   m_ = input_shape_[dims_ - kDim2];
@@ -74,57 +63,46 @@ int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
   for (size_t i = 0; i < dims_ - kDim2; i++) {
     batch_size_ = batch_size_ * input_shape_[i];
   }
-  for (size_t i = 0; i < dims_; ++i) {
-    transpose_input_shape_[i] = input_shape_[i];
-    if (i == dims_ - kDim2) {
-      transpose_input_axis_[i] = dims_ - kDim1;
-    } else if (i == dims_ - kDim1) {
-      transpose_input_axis_[i] = dims_ - kDim2;
-    } else {
-      transpose_input_axis_[i] = i;
-    }
-  }
-  init_size_lists_func_(this);
   constexpr auto kBatchedMaxRowCol = 32;
   if (m_ <= kBatchedMaxRowCol && n_ <= kBatchedMaxRowCol && batch_size_ > 1 && (full_matrices_ || m_ == n_)) {
     batched_ = true;
   }
+  unit_size_ = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
+  ResetResource();
+  InitSizeLists();
   return 0;
 }

-template <typename T>
 void SvdGpuKernelMod::InitSizeLists() {
   // input a
-  input_size_list_.push_back(total_size_ * sizeof(T));
+  input_size_list_.push_back(total_size_ * unit_size_);
   // output s, u, v
-  output_size_list_.push_back(batch_size_ * p_ * sizeof(T));
+  output_size_list_.push_back(batch_size_ * p_ * unit_size_);
   if (compute_uv_) {
     if (full_matrices_) {
-      output_size_list_.push_back(batch_size_ * m_ * m_ * sizeof(T));
-      output_size_list_.push_back(batch_size_ * n_ * n_ * sizeof(T));
+      output_size_list_.push_back(batch_size_ * m_ * m_ * unit_size_);
+      output_size_list_.push_back(batch_size_ * n_ * n_ * unit_size_);
     } else {
-      output_size_list_.push_back(batch_size_ * m_ * p_ * sizeof(T));
-      output_size_list_.push_back(batch_size_ * n_ * p_ * sizeof(T));
+      output_size_list_.push_back(batch_size_ * m_ * p_ * unit_size_);
+      output_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
     }
   } else {
     output_size_list_.push_back(0);
     output_size_list_.push_back(0);
   }
-  workspace_size_list_.push_back(batch_size_ * sizeof(int));  // for dev_info
+  // for dev_info
+  workspace_size_list_.push_back(batch_size_ * sizeof(int));
   // for transpose input
-  workspace_size_list_.push_back(dims_ * sizeof(size_t));
-  workspace_size_list_.push_back(dims_ * sizeof(size_t));
-  workspace_size_list_.push_back(total_size_ * sizeof(T));
+  workspace_size_list_.push_back(total_size_ * unit_size_);

   // for dev_u and dev_v
-  if (compute_uv_) {
-    if (full_matrices_) {
-      workspace_size_list_.push_back(batch_size_ * m_ * m_ * sizeof(T));
-      workspace_size_list_.push_back(batch_size_ * n_ * n_ * sizeof(T));
+  if (compute_uv_ || batched_) {
+    if (full_matrices_ || batched_) {
+      workspace_size_list_.push_back(batch_size_ * m_ * m_ * unit_size_);
+      workspace_size_list_.push_back(batch_size_ * n_ * n_ * unit_size_);
     } else {
-      workspace_size_list_.push_back(batch_size_ * m_ * p_ * sizeof(T));
-      workspace_size_list_.push_back(batch_size_ * n_ * p_ * sizeof(T));
+      workspace_size_list_.push_back(batch_size_ * m_ * p_ * unit_size_);
+      workspace_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
    }
  }
 }
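After this refactor InitSizeLists is no longer templated: the element width comes from unit_size_, and the workspace list has a fixed positional layout that LaunchKernel indexes directly (slot 0: dev_info, slot 1: the column-major copy of the input, slots 2 and 3: the pre-transpose u and v buffers, allocated when compute_uv_ or batched_ is set). A Python sketch of the same sizing logic, with unit_size the dtype width in bytes:

def svd_workspace_sizes(batch, m, n, p, total, unit_size, compute_uv, full_matrices, batched):
    # Mirrors SvdGpuKernelMod::InitSizeLists after this patch (sizeof(int) assumed 4).
    sizes = [batch * 4,           # slot 0: dev_info, one int per matrix in the batch
             total * unit_size]   # slot 1: transposed (column-major) input
    if compute_uv or batched:
        if full_matrices or batched:  # batched gesvdj always writes square u and v
            sizes += [batch * m * m * unit_size, batch * n * n * unit_size]
        else:
            sizes += [batch * m * p * unit_size, batch * n * p * unit_size]
    return sizes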
@@ -133,7 +111,6 @@ template <typename T>
 void SvdGpuKernelMod::RunSvd(const size_t m, const size_t n, const size_t batch, T *d_a, int *dev_info, T *output_s,
                              T *d_output_u, T *d_output_v) {
   int lwork = 0;
-
   if constexpr (std::is_same_v<T, float>) {
     CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnSgesvd_bufferSize(handle_, m, n, &lwork),
                                            "cusolver query svd work size fail");
@@ -208,59 +185,35 @@ void SvdGpuKernelMod::RunSvdBatched(const size_t m, const size_t n, T *d_input,
                                            "cusolver svd fail");
   }

+  CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnDestroyGesvdjInfo(info), "cusolver svd fail");
   device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(work);
 }

 template <typename T>
-void SvdGpuKernelMod::TransposeUV(const size_t m, const size_t n, size_t *d_transpose_input_shape,
-                                  size_t *d_transpose_input_axis, T *d_output_u, T *d_output_v, T *output_u,
+void SvdGpuKernelMod::TransposeUV(const size_t m, const size_t n, T *d_output_u, T *d_output_v, T *output_u,
                                   T *output_v) {
   if (full_matrices_) {
-    transpose_input_shape_[dims_ - kDim2] = m;
-    transpose_input_shape_[dims_ - kDim1] = m;
-    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-      cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                      reinterpret_cast<cudaStream_t>(cuda_stream_)),
-      "cuda memcpy failed!");
-    CalTranspose(batch_size_ * m * m, d_output_u, d_transpose_input_shape, d_transpose_input_axis, dims_, output_u,
-                 reinterpret_cast<cudaStream_t>(cuda_stream_));
+    MatrixTranspose(d_output_u, SizeToInt(batch_size_ * m * m), SizeToInt(m), SizeToInt(m), output_u, device_id_,
+                    reinterpret_cast<cudaStream_t>(cuda_stream_));
     if (batched_) {
-      transpose_input_shape_[dims_ - kDim2] = n;
-      transpose_input_shape_[dims_ - kDim1] = n;
-      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                        reinterpret_cast<cudaStream_t>(cuda_stream_)),
-        "cuda memcpy failed!");
-      CalTranspose(batch_size_ * n * n, d_output_v, d_transpose_input_shape, d_transpose_input_axis, dims_, output_v,
-                   reinterpret_cast<cudaStream_t>(cuda_stream_));
+      MatrixTranspose(d_output_v, SizeToInt(batch_size_ * n * n), SizeToInt(n), SizeToInt(n), output_v, device_id_,
+                      reinterpret_cast<cudaStream_t>(cuda_stream_));
     } else {
       CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(output_v, d_output_v, sizeof(size_t) * batch_size_ * n * n, cudaMemcpyHostToDevice,
+        cudaMemcpyAsync(output_v, d_output_v, sizeof(T) * batch_size_ * n * n, cudaMemcpyDeviceToDevice,
                         reinterpret_cast<cudaStream_t>(cuda_stream_)),
         "cuda memcpy failed!");
     }
   } else {
-    transpose_input_shape_[dims_ - kDim2] = p_;
-    transpose_input_shape_[dims_ - kDim1] = m;
-    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-      cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                      reinterpret_cast<cudaStream_t>(cuda_stream_)),
-      "cuda memcpy failed!");
-    CalTranspose(batch_size_ * m * p_, d_output_u, d_transpose_input_shape, d_transpose_input_axis, dims_, output_u,
-                 reinterpret_cast<cudaStream_t>(cuda_stream_));
+    MatrixTranspose(d_output_u, SizeToInt(batch_size_ * m * p_), SizeToInt(p_), SizeToInt(m), output_u, device_id_,
+                    reinterpret_cast<cudaStream_t>(cuda_stream_));

     if (batched_) {
-      transpose_input_shape_[dims_ - kDim2] = p_;
-      transpose_input_shape_[dims_ - kDim1] = n;
-      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                        reinterpret_cast<cudaStream_t>(cuda_stream_)),
-        "cuda memcpy failed!");
-      CalTranspose(batch_size_ * n * p_, d_output_v, d_transpose_input_shape, d_transpose_input_axis, dims_, output_v,
-                   reinterpret_cast<cudaStream_t>(cuda_stream_));
+      MatrixTranspose(d_output_v, SizeToInt(batch_size_ * n * p_), SizeToInt(p_), SizeToInt(n), output_v, device_id_,
+                      reinterpret_cast<cudaStream_t>(cuda_stream_));
     } else {
       CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(output_v, d_output_v, sizeof(size_t) * batch_size_ * n * p_, cudaMemcpyHostToDevice,
+        cudaMemcpyAsync(output_v, d_output_v, sizeof(T) * batch_size_ * n * p_, cudaMemcpyDeviceToDevice,
                         reinterpret_cast<cudaStream_t>(cuda_stream_)),
         "cuda memcpy failed!");
     }
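The rewrite above drops the generic CalTranspose path (which staged shape and axis arrays through extra device buffers) in favor of the dedicated MatrixTranspose helper, which needs only row and column counts. The transpose exists because cuSOLVER writes u and v column-major while MindSpore tensors are row-major; reading a column-major buffer as row-major yields the transpose. A NumPy illustration (stand-in only):

import numpy as np

u = np.random.rand(4, 3)
col_major_bytes = u.ravel(order='F')     # device-side (column-major) layout of u
misread = col_major_bytes.reshape(3, 4)  # a naive row-major view comes out as u.T
assert np.array_equal(misread, u.T)      # hence the explicit transpose kernel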
@@ -286,8 +239,7 @@ void SvdGpuKernelMod::CheckResult(int *dev_info) {

 template <typename T>
 void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *output_s, T *output_u, T *output_v,
-                                T *d_output_u, T *d_output_v, int *dev_info, size_t *d_transpose_input_shape,
-                                size_t *d_transpose_input_axis) {
+                                T *d_output_u, T *d_output_v, int *dev_info) {
   if (batched_) {
     RunSvdBatched(m, n, d_input, output_s, d_output_u, d_output_v, dev_info);
   } else {
@@ -297,7 +249,7 @@ void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *o
   }

   if (compute_uv_) {
-    TransposeUV(m, n, d_transpose_input_shape, d_transpose_input_axis, d_output_u, d_output_v, output_u, output_v);
+    TransposeUV(m, n, d_output_u, d_output_v, output_u, output_v);
   }

   CheckResult(dev_info);
@@ -306,9 +258,6 @@ void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *o
 template <typename T>
 bool SvdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs) {
-  if (is_null_input_) {
-    return true;
-  }
   CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
                                 "CusolverDnSetStream failed");

@@ -318,67 +267,55 @@ bool SvdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
   T *output_v = nullptr;
   T *d_output_u = nullptr;
   T *d_output_v = nullptr;
-  if (compute_uv_) {
-    output_u = GetDeviceAddress<T>(outputs, kIndex1);
-    output_v = GetDeviceAddress<T>(outputs, kIndex2);
+  if (compute_uv_ || batched_) {
+    if (compute_uv_) {
+      output_u = GetDeviceAddress<T>(outputs, kIndex1);
+      output_v = GetDeviceAddress<T>(outputs, kIndex2);
+    }
     // Store output u and v before transpose.
-    d_output_u = GetDeviceAddress<T>(workspace, kIndex4);
-    d_output_v = GetDeviceAddress<T>(workspace, kIndex5);
+    d_output_u = GetDeviceAddress<T>(workspace, kIndex2);
+    d_output_v = GetDeviceAddress<T>(workspace, kIndex3);
   }

   int *dev_info = GetDeviceAddress<int>(workspace, kIndex0);

-  size_t *d_transpose_input_shape = GetDeviceAddress<size_t>(workspace, kIndex1);
-  size_t *d_transpose_input_axis = GetDeviceAddress<size_t>(workspace, kIndex2);
-  T *d_input = GetDeviceAddress<T>(workspace, kIndex3);
-
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-    cudaMemcpyAsync(d_transpose_input_axis, transpose_input_axis_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                    reinterpret_cast<cudaStream_t>(cuda_stream_)),
-    "cuda memcpy failed!");
+  T *d_input = GetDeviceAddress<T>(workspace, kIndex1);

   if (m_ge_n_) {
     // Because cudaSovler expects column-major matrix, we need transpose A.
+    MatrixTranspose(input, SizeToInt(total_size_), SizeToInt(m_), SizeToInt(n_), d_input, device_id_,
+                    reinterpret_cast<cudaStream_t>(cuda_stream_));
+    LaunchSvd(m_, n_, d_input, output_s, output_u, output_v, d_output_u, d_output_v, dev_info);
+  } else {
     CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-      cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
+      cudaMemcpyAsync(d_input, input, sizeof(T) * total_size_, cudaMemcpyDeviceToDevice,
                       reinterpret_cast<cudaStream_t>(cuda_stream_)),
       "cuda memcpy failed!");
-    CalTranspose(total_size_, input, d_transpose_input_shape, d_transpose_input_axis, dims_, d_input,
-                 reinterpret_cast<cudaStream_t>(cuda_stream_));
-    LaunchSvd(m_, n_, d_input, output_s, output_u, output_v, d_output_u, d_output_v, dev_info, d_transpose_input_shape,
-              d_transpose_input_axis);
-  } else {
-    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(d_input, input, sizeof(T) * total_size_, cudaMemcpyHostToDevice,
-                                                       reinterpret_cast<cudaStream_t>(cuda_stream_)),
-                                       "cuda memcpy failed!");
-    LaunchSvd(n_, m_, d_input, output_s, output_v, output_u, d_output_v, d_output_u, dev_info, d_transpose_input_shape,
-              d_transpose_input_axis);
+    LaunchSvd(n_, m_, d_input, output_s, output_v, output_u, d_output_v, d_output_u, dev_info);
   }

   return true;
 }

-std::vector<std::pair<KernelAttr, std::pair<SvdGpuKernelMod::LaunchKernelFunc, SvdGpuKernelMod::InitSizeListsFunc>>>
-  SvdGpuKernelMod::func_list_ = {
-    {KernelAttr()
-       .AddInputAttr(kNumberTypeFloat32)
-       .AddOutputAttr(kNumberTypeFloat32)
-       .AddOutputAttr(kNumberTypeFloat32)
-       .AddOutputAttr(kNumberTypeFloat32),
-     {&SvdGpuKernelMod::LaunchKernel<float>, &SvdGpuKernelMod::InitSizeLists<float>}},
-    {KernelAttr()
-       .AddInputAttr(kNumberTypeFloat64)
-       .AddOutputAttr(kNumberTypeFloat64)
-       .AddOutputAttr(kNumberTypeFloat64)
-       .AddOutputAttr(kNumberTypeFloat64),
-     {&SvdGpuKernelMod::LaunchKernel<double>, &SvdGpuKernelMod::InitSizeLists<double>}},
+std::vector<std::pair<KernelAttr, SvdGpuKernelMod::LaunchKernelFunc>> SvdGpuKernelMod::func_list_ = {
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &SvdGpuKernelMod::LaunchKernel<float>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64),
+   &SvdGpuKernelMod::LaunchKernel<double>},
 };

 std::vector<KernelAttr> SvdGpuKernelMod::GetOpSupport() {
   std::vector<KernelAttr> support_list;
-  (void)std::transform(
-    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
-    [](const std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>> &pair) { return pair.first; });
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, LaunchKernelFunc> &pair) { return pair.first; });
   return support_list;
 }
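LaunchKernel also keeps the m_ < n_ trick: the input is handed to gesvd as its transpose and the u and v destinations are swapped, relying on the identity that the SVD of :math:`A^T` exchanges the roles of u and v. A NumPy check of that identity (illustration only):

import numpy as np

a = np.random.rand(3, 5)                          # m < n
u, s, vt = np.linalg.svd(a, full_matrices=False)
ut, st, vtt = np.linalg.svd(a.T, full_matrices=False)
assert np.allclose(st, s)                         # singular values unchanged
assert np.allclose(np.abs(ut), np.abs(vt.T))      # u of A^T is v of A (up to column signs)
assert np.allclose(np.abs(vtt), np.abs(u.T))      # v of A^T is u of A (up to column signs)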
@@ -31,6 +31,7 @@
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_transpose_impl.cuh"

 namespace mindspore {
 namespace kernel {
@@ -46,16 +47,11 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {

   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (is_null_input_) {
-      return true;
-    }
     cuda_stream_ = stream_ptr;
     return launch_kernel_func_(this, inputs, workspace, outputs);
   }

   void ResetResource() noexcept {
-    is_null_input_ = false;
-
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
@@ -64,7 +60,6 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
   std::vector<KernelAttr> GetOpSupport() override;

  protected:
-  template <typename T>
   void InitSizeLists();

   template <typename T>
@@ -76,11 +71,10 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
   template <typename T>
   void RunSvdBatched(const size_t m, const size_t n, T *d_input, T *output_s, T *output_u, T *output_v, int *dev_info);
   template <typename T>
-  void TransposeUV(const size_t m, const size_t n, size_t *d_transpose_input_shape, size_t *d_transpose_input_axis,
-                   T *d_output_u, T *d_output_v, T *output_u, T *output_v);
+  void TransposeUV(const size_t m, const size_t n, T *d_output_u, T *d_output_v, T *output_u, T *output_v);
   template <typename T>
   void LaunchSvd(const size_t m, const size_t n, T *d_input, T *output_s, T *output_u, T *output_v, T *d_output_u,
-                 T *d_output_v, int *dev_info, size_t *d_transpose_input_shape, size_t *d_transpose_input_axis);
+                 T *d_output_v, int *dev_info);
   void CheckResult(int *dev_info);

   using LaunchKernelFunc =
@@ -88,9 +82,9 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
                        const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
   using InitSizeListsFunc = std::function<void(SvdGpuKernelMod *)>;
   LaunchKernelFunc launch_kernel_func_{nullptr};
-  InitSizeListsFunc init_size_lists_func_{nullptr};
-  static std::vector<std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>>> func_list_;
+  static std::vector<std::pair<KernelAttr, LaunchKernelFunc>> func_list_;

+  size_t unit_size_{1};
   bool compute_uv_{false};
   bool full_matrices_{false};
   std::vector<size_t> input_shape_;
@@ -104,9 +98,6 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
   bool m_ge_n_{false};
   bool batched_{false};

-  size_t transpose_input_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
-  size_t transpose_input_axis_[TRANSPOSE_MAX_DIMENSION] = {0};
-  bool is_null_input_;
   cusolverDnHandle_t handle_{nullptr};
   void *cuda_stream_{nullptr};
 };
@@ -4750,7 +4750,7 @@ class Tensor(Tensor_):
             - **u** (Tensor) - Left singular vectors. If compute_uv is False, u will not be returned.
               The shape is :math:`(*, M, P)`. If full_matrices is True, the shape will be :math:`(*, M, M)`.
             - **v** (Tensor) - Right singular vectors. If compute_uv is False, v will not be returned.
-              The shape is :math:`(*, P, N)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.
+              The shape is :math:`(*, N, P)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.

         Raises:
             TypeError: If full_matrices or compute_uv is not the type of bool.
@@ -42,7 +42,7 @@ def svd(a, full_matrices=False, compute_uv=True):
         - **u** (Tensor) - Left singular vectors. If `compute_uv` is False, u will not be returned.
           The shape is :math:`(*, M, P)`. If `full_matrices` is True, the shape will be :math:`(*, M, M)`.
         - **v** (Tensor) - Right singular vectors. If `compute_uv` is False, v will not be returned.
-          The shape is :math:`(*, P, N)`. If `full_matrices` is True, the shape will be :math:`(*, N, N)`.
+          The shape is :math:`(*, N, P)`. If `full_matrices` is True, the shape will be :math:`(*, N, N)`.

     Raises:
         TypeError: If `full_matrices` or `compute_uv` is not the type of bool.
@@ -82,7 +82,7 @@ class Svd(Primitive):
         - **u** (Tensor) - Left singular vectors. If compute_uv is False, u will be an empty tensor.
           The shape is :math:`(*, M, P)`. If full_matrices is True, the shape will be :math:`(*, M, M)`.
         - **v** (Tensor) - Right singular vectors. If compute_uv is False, v will be an empty tensor.
-          The shape is :math:`(*, P, N)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.
+          The shape is :math:`(*, N, P)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.

     Raises:
         TypeError: If full_matrices or compute_uv is not the type of bool.
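The Svd primitive above documents the compute_uv=False case slightly differently from the functional API: u and v come back as empty tensors rather than being omitted. A short sketch of the values-only path through ops.svd, assuming a MindSpore build with SVD support:

import numpy as np
from mindspore import Tensor, ops

a = Tensor(np.random.rand(4, 3).astype(np.float32))
s = ops.svd(a, compute_uv=False)   # only the singular values, shape (3,)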