fix gpu svd

huanghui 2022-10-25 10:43:32 +08:00
parent 51dfb054d0
commit 6d35c8b3ca
7 changed files with 73 additions and 145 deletions

View File

@@ -14,7 +14,7 @@ mindspore.Tensor.svd
Returns:
- **s** (Tensor) - The singular values. The shape is :math:`(*, P)`.
- **u** (Tensor) - The left singular vectors. If compute_uv is False, u will not be returned. The shape is :math:`(*, M, P)`; if full_matrices is True, the shape is :math:`(*, M, M)`.
- **v** (Tensor) - The right singular vectors. If compute_uv is False, v will not be returned. The shape is :math:`(*, P, N)`; if full_matrices is True, the shape is :math:`(*, N, N)`.
- **v** (Tensor) - The right singular vectors. If compute_uv is False, v will not be returned. The shape is :math:`(*, N, P)`; if full_matrices is True, the shape is :math:`(*, N, N)`.
Raises:
- **TypeError** - If `full_matrices` or `compute_uv` is not a bool.
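The one-character change above is a dimensional fix: with P = min(M, N), it is Vᵀ, of shape (*, P, N), that enters the reconstruction, so v itself, which stores one right singular vector per column, must have shape (*, N, P). As a sanity check (not part of the diff):

```latex
A \approx U \,\operatorname{diag}(s)\, V^{\mathsf{T}},
\qquad
U \in \mathbb{R}^{M \times P},\quad
s \in \mathbb{R}^{P},\quad
V \in \mathbb{R}^{N \times P},\quad
P = \min(M, N).
```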

View File

@@ -18,7 +18,7 @@ mindspore.ops.svd
Returns:
- **s** (Tensor) - The singular values. The shape is :math:`(*, P)`.
- **u** (Tensor) - The left singular vectors. If `compute_uv` is False, u will not be returned. The shape is :math:`(*, M, P)`; if `full_matrices` is True, the shape is :math:`(*, M, M)`.
- **v** (Tensor) - The right singular vectors. If `compute_uv` is False, v will not be returned. The shape is :math:`(*, P, N)`; if `full_matrices` is True, the shape is :math:`(*, N, N)`.
- **v** (Tensor) - The right singular vectors. If `compute_uv` is False, v will not be returned. The shape is :math:`(*, N, P)`; if `full_matrices` is True, the shape is :math:`(*, N, N)`.
Raises:
- **TypeError** - If `full_matrices` or `compute_uv` is not a bool.

View File

@@ -31,8 +31,7 @@ bool SvdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vect
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
return false;
}
launch_kernel_func_ = func_list_[index].second.first;
init_size_lists_func_ = func_list_[index].second.second;
launch_kernel_func_ = func_list_[index].second;
compute_uv_ = kernel_ptr->compute_uv();
full_matrices_ = kernel_ptr->full_matrices();
job_ = compute_uv_ ? (full_matrices_ ? 'A' : 'S') : 'N';
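The job character computed here follows cuSOLVER's gesvd convention, and the kernel later passes it as both jobu and jobvt. A minimal standalone restatement of the mapping (plain C++; only the ternary itself comes from the diff):

```cpp
#include <cstdio>

int main() {
  // cuSOLVER gesvd job codes: 'A' = all M (or N) columns of the factor,
  // 'S' = only the first min(M, N) columns (thin SVD), 'N' = skip the factor.
  const bool compute_uv = true, full_matrices = false;
  const char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
  std::printf("job = %c\n", job);  // thin SVD -> prints "job = S"
  return 0;
}
```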
@@ -47,21 +46,11 @@ int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
return ret;
}
auto input_shape = inputs[kIndex0]->GetShapeVector();
if (IsDynamicRank(input_shape)) {
return KRET_OK;
}
DestroyResource();
ResetResource();
input_shape_ = Convert2SizeTClipNeg(input_shape);
total_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies<size_t>());
is_null_input_ = (total_size_ == 0);
if (is_null_input_) {
init_size_lists_func_(this);
return KRET_OK;
}
dims_ = input_shape_.size();
if (dims_ < kDim2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got [" << dims_;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got " << dims_;
}
m_ = input_shape_[dims_ - kDim2];
@@ -74,57 +63,46 @@ int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
for (size_t i = 0; i < dims_ - kDim2; i++) {
batch_size_ = batch_size_ * input_shape_[i];
}
for (size_t i = 0; i < dims_; ++i) {
transpose_input_shape_[i] = input_shape_[i];
if (i == dims_ - kDim2) {
transpose_input_axis_[i] = dims_ - kDim1;
} else if (i == dims_ - kDim1) {
transpose_input_axis_[i] = dims_ - kDim2;
} else {
transpose_input_axis_[i] = i;
}
}
init_size_lists_func_(this);
constexpr auto kBatchedMaxRowCol = 32;
if (m_ <= kBatchedMaxRowCol && n_ <= kBatchedMaxRowCol && batch_size_ > 1 && (full_matrices_ || m_ == n_)) {
batched_ = true;
}
unit_size_ = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
ResetResource();
InitSizeLists();
return 0;
}
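The kBatchedMaxRowCol = 32 ceiling mirrors a documented cuSOLVER restriction: gesvdjBatched only accepts matrices with m, n <= 32, and it always produces square singular-vector factors, so the fast path is valid only when full matrices were requested anyway (or the input is square, where thin and full factors coincide). Restated as a standalone predicate (a sketch, not code from the diff):

```cpp
// Sketch: when the kernel can take the one-launch gesvdjBatched path instead
// of looping cusolverDn<t>gesvd over the batch.
bool CanUseBatchedSvd(size_t m, size_t n, size_t batch_size,
                      bool full_matrices) {
  constexpr size_t kBatchedMaxRowCol = 32;  // documented gesvdjBatched limit
  return m <= kBatchedMaxRowCol && n <= kBatchedMaxRowCol &&
         batch_size > 1 && (full_matrices || m == n);
}
```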
template <typename T>
void SvdGpuKernelMod::InitSizeLists() {
// input a
input_size_list_.push_back(total_size_ * sizeof(T));
input_size_list_.push_back(total_size_ * unit_size_);
// output s, u, v
output_size_list_.push_back(batch_size_ * p_ * sizeof(T));
output_size_list_.push_back(batch_size_ * p_ * unit_size_);
if (compute_uv_) {
if (full_matrices_) {
output_size_list_.push_back(batch_size_ * m_ * m_ * sizeof(T));
output_size_list_.push_back(batch_size_ * n_ * n_ * sizeof(T));
output_size_list_.push_back(batch_size_ * m_ * m_ * unit_size_);
output_size_list_.push_back(batch_size_ * n_ * n_ * unit_size_);
} else {
output_size_list_.push_back(batch_size_ * m_ * p_ * sizeof(T));
output_size_list_.push_back(batch_size_ * n_ * p_ * sizeof(T));
output_size_list_.push_back(batch_size_ * m_ * p_ * unit_size_);
output_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
}
} else {
output_size_list_.push_back(0);
output_size_list_.push_back(0);
}
workspace_size_list_.push_back(batch_size_ * sizeof(int)); // for dev_info
// for dev_info
workspace_size_list_.push_back(batch_size_ * sizeof(int));
// for transpose input
workspace_size_list_.push_back(dims_ * sizeof(size_t));
workspace_size_list_.push_back(dims_ * sizeof(size_t));
workspace_size_list_.push_back(total_size_ * sizeof(T));
workspace_size_list_.push_back(total_size_ * unit_size_);
// for dev_u and dev_v
if (compute_uv_) {
if (full_matrices_) {
workspace_size_list_.push_back(batch_size_ * m_ * m_ * sizeof(T));
workspace_size_list_.push_back(batch_size_ * n_ * n_ * sizeof(T));
if (compute_uv_ || batched_) {
if (full_matrices_ || batched_) {
workspace_size_list_.push_back(batch_size_ * m_ * m_ * unit_size_);
workspace_size_list_.push_back(batch_size_ * n_ * n_ * unit_size_);
} else {
workspace_size_list_.push_back(batch_size_ * m_ * p_ * sizeof(T));
workspace_size_list_.push_back(batch_size_ * n_ * p_ * sizeof(T));
workspace_size_list_.push_back(batch_size_ * m_ * p_ * unit_size_);
workspace_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
}
}
}
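With the two transpose-metadata slots gone, the workspace vector settles into a fixed four-slot layout, which is why LaunchKernel below reads the U/V scratch from kIndex2/kIndex3 rather than the old kIndex4/kIndex5. Summarized below (the enum and its names are illustrative, not from the source):

```cpp
// Sketch: workspace slots in the push order of InitSizeLists after this fix.
enum SvdWorkspaceSlot {
  kSlotDevInfo = 0,   // batch_size_ * sizeof(int): per-matrix cuSOLVER status
  kSlotDevInput = 1,  // total_size_ * unit_size_: column-major input copy
  kSlotDevU = 2,      // U scratch, allocated only when compute_uv_ || batched_
  kSlotDevV = 3,      // V scratch, allocated only when compute_uv_ || batched_
};
```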
@@ -133,7 +111,6 @@ template <typename T>
void SvdGpuKernelMod::RunSvd(const size_t m, const size_t n, const size_t batch, T *d_a, int *dev_info, T *output_s,
T *d_output_u, T *d_output_v) {
int lwork = 0;
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnSgesvd_bufferSize(handle_, m, n, &lwork),
"cusolver query svd work size fail");
@@ -208,59 +185,35 @@ void SvdGpuKernelMod::RunSvdBatched(const size_t m, const size_t n, T *d_input,
"cusolver svd fail");
}
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnDestroyGesvdjInfo(info), "cusolver svd fail");
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(work);
}
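The destroy/free tail above closes cuSOLVER's standard create, query, run, destroy sequence for the Jacobi-based batched SVD. A condensed float sketch of the whole sequence (buffer names and the raw cudaMalloc are illustrative; the kernel allocates through GPUMemoryAllocator and wraps every call in its CHECK_* macros):

```cpp
#include <cuda_runtime.h>
#include <cusolverDn.h>

// Sketch: SVD of `batch` column-major m x n matrices in a single launch.
// gesvdjBatched is documented for 1 <= m, n <= 32 and always writes square
// U (m x m) and V (n x n) -- note it returns V, not V^T. The U/V buffers are
// used even in NOVECTOR mode, which matches InitSizeLists above sizing the
// U/V workspace whenever batched_ is set, not only when compute_uv_ is.
void RunSvdBatchedFloatSketch(cusolverDnHandle_t handle, bool compute_uv,
                              int m, int n, int batch, float *d_a, float *d_s,
                              float *d_u, float *d_v, int *d_info) {
  const cusolverEigMode_t jobz =
      compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
  gesvdjInfo_t params = nullptr;
  cusolverDnCreateGesvdjInfo(&params);                         // create
  int lwork = 0;
  cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, d_a, /*lda=*/m,
                                      d_s, d_u, /*ldu=*/m, d_v, /*ldv=*/n,
                                      &lwork, params, batch);  // query
  float *d_work = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(float) * lwork);
  cusolverDnSgesvdjBatched(handle, jobz, m, n, d_a, m, d_s, d_u, m, d_v, n,
                           d_work, lwork, d_info, params, batch);  // run
  cudaFree(d_work);
  cusolverDnDestroyGesvdjInfo(params);                         // destroy
}
```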
template <typename T>
void SvdGpuKernelMod::TransposeUV(const size_t m, const size_t n, size_t *d_transpose_input_shape,
size_t *d_transpose_input_axis, T *d_output_u, T *d_output_v, T *output_u,
void SvdGpuKernelMod::TransposeUV(const size_t m, const size_t n, T *d_output_u, T *d_output_v, T *output_u,
T *output_v) {
if (full_matrices_) {
transpose_input_shape_[dims_ - kDim2] = m;
transpose_input_shape_[dims_ - kDim1] = m;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(batch_size_ * m * m, d_output_u, d_transpose_input_shape, d_transpose_input_axis, dims_, output_u,
reinterpret_cast<cudaStream_t>(cuda_stream_));
MatrixTranspose(d_output_u, SizeToInt(batch_size_ * m * m), SizeToInt(m), SizeToInt(m), output_u, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
if (batched_) {
transpose_input_shape_[dims_ - kDim2] = n;
transpose_input_shape_[dims_ - kDim1] = n;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(batch_size_ * n * n, d_output_v, d_transpose_input_shape, d_transpose_input_axis, dims_, output_v,
reinterpret_cast<cudaStream_t>(cuda_stream_));
MatrixTranspose(d_output_v, SizeToInt(batch_size_ * n * n), SizeToInt(n), SizeToInt(n), output_v, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
} else {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output_v, d_output_v, sizeof(size_t) * batch_size_ * n * n, cudaMemcpyHostToDevice,
cudaMemcpyAsync(output_v, d_output_v, sizeof(T) * batch_size_ * n * n, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
}
} else {
transpose_input_shape_[dims_ - kDim2] = p_;
transpose_input_shape_[dims_ - kDim1] = m;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(batch_size_ * m * p_, d_output_u, d_transpose_input_shape, d_transpose_input_axis, dims_, output_u,
reinterpret_cast<cudaStream_t>(cuda_stream_));
MatrixTranspose(d_output_u, SizeToInt(batch_size_ * m * p_), SizeToInt(p_), SizeToInt(m), output_u, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
if (batched_) {
transpose_input_shape_[dims_ - kDim2] = p_;
transpose_input_shape_[dims_ - kDim1] = n;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(batch_size_ * n * p_, d_output_v, d_transpose_input_shape, d_transpose_input_axis, dims_, output_v,
reinterpret_cast<cudaStream_t>(cuda_stream_));
MatrixTranspose(d_output_v, SizeToInt(batch_size_ * n * p_), SizeToInt(p_), SizeToInt(n), output_v, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
} else {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output_v, d_output_v, sizeof(size_t) * batch_size_ * n * p_, cudaMemcpyHostToDevice,
cudaMemcpyAsync(output_v, d_output_v, sizeof(T) * batch_size_ * n * p_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
}
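Every branch above ends by converting cuSOLVER's column-major factor back to the row-major layout MindSpore tensors use. The old CalTranspose was a generic N-D transpose driven by host-side shape/axis arrays, which is what forced the extra host-to-device memcpys (and the two workspace slots) this commit deletes; MatrixTranspose needs only the matrix dimensions. A naive CUDA sketch of what such a batched last-two-axes transpose does, with the same argument order as the calls above (the real helper lives in matrix_transpose_impl.cuh):

```cpp
// Sketch: transpose the trailing two axes of a batch of matrices.
// `size` counts elements over the whole batch; the input is viewed as
// row-major (batch, rows, cols) and written as (batch, cols, rows).
__global__ void BatchedTransposeSketch(const float *in, int size, int rows,
                                       int cols, float *out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += gridDim.x * blockDim.x) {
    const int matrix = i / (rows * cols);
    const int r = (i % (rows * cols)) / cols;
    const int c = i % cols;
    // Element (matrix, r, c) of the input lands at (matrix, c, r).
    out[matrix * rows * cols + c * rows + r] = in[i];
  }
}
```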
@@ -286,8 +239,7 @@ void SvdGpuKernelMod::CheckResult(int *dev_info) {
template <typename T>
void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *output_s, T *output_u, T *output_v,
T *d_output_u, T *d_output_v, int *dev_info, size_t *d_transpose_input_shape,
size_t *d_transpose_input_axis) {
T *d_output_u, T *d_output_v, int *dev_info) {
if (batched_) {
RunSvdBatched(m, n, d_input, output_s, d_output_u, d_output_v, dev_info);
} else {
@@ -297,7 +249,7 @@ void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *o
}
if (compute_uv_) {
TransposeUV(m, n, d_transpose_input_shape, d_transpose_input_axis, d_output_u, d_output_v, output_u, output_v);
TransposeUV(m, n, d_output_u, d_output_v, output_u, output_v);
}
CheckResult(dev_info);
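CheckResult is the usual epilogue for per-matrix cuSOLVER status flags: info == 0 means success, info == -i flags an invalid i-th argument, and for gesvd a positive value counts superdiagonals that failed to converge. A host-side sketch (names are illustrative, and the real kernel presumably reports through MS_LOG):

```cpp
#include <cuda_runtime.h>
#include <vector>

// Sketch: copy the per-matrix cuSOLVER info flags back and inspect them.
void CheckResultSketch(const int *dev_info, size_t batch, cudaStream_t stream) {
  std::vector<int> info(batch);
  cudaMemcpyAsync(info.data(), dev_info, sizeof(int) * batch,
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // info is only valid after the copy lands
  for (size_t b = 0; b < batch; ++b) {
    if (info[b] != 0) {
      // info < 0: argument -info[b] was invalid; info > 0 (gesvd): that many
      // superdiagonals of the intermediate bidiagonal form did not converge.
    }
  }
}
```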
@@ -306,9 +258,6 @@ void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *o
template <typename T>
bool SvdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (is_null_input_) {
return true;
}
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"CusolverDnSetStream failed");
@@ -318,67 +267,55 @@ bool SvdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
T *output_v = nullptr;
T *d_output_u = nullptr;
T *d_output_v = nullptr;
if (compute_uv_) {
output_u = GetDeviceAddress<T>(outputs, kIndex1);
output_v = GetDeviceAddress<T>(outputs, kIndex2);
if (compute_uv_ || batched_) {
if (compute_uv_) {
output_u = GetDeviceAddress<T>(outputs, kIndex1);
output_v = GetDeviceAddress<T>(outputs, kIndex2);
}
// Store output u and v before transpose.
d_output_u = GetDeviceAddress<T>(workspace, kIndex4);
d_output_v = GetDeviceAddress<T>(workspace, kIndex5);
d_output_u = GetDeviceAddress<T>(workspace, kIndex2);
d_output_v = GetDeviceAddress<T>(workspace, kIndex3);
}
int *dev_info = GetDeviceAddress<int>(workspace, kIndex0);
size_t *d_transpose_input_shape = GetDeviceAddress<size_t>(workspace, kIndex1);
size_t *d_transpose_input_axis = GetDeviceAddress<size_t>(workspace, kIndex2);
T *d_input = GetDeviceAddress<T>(workspace, kIndex3);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_transpose_input_axis, transpose_input_axis_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
T *d_input = GetDeviceAddress<T>(workspace, kIndex1);
if (m_ge_n_) {
// Because cuSolver expects a column-major matrix, we need to transpose A.
MatrixTranspose(input, SizeToInt(total_size_), SizeToInt(m_), SizeToInt(n_), d_input, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
LaunchSvd(m_, n_, d_input, output_s, output_u, output_v, d_output_u, d_output_v, dev_info);
} else {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
cudaMemcpyAsync(d_input, input, sizeof(T) * total_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(total_size_, input, d_transpose_input_shape, d_transpose_input_axis, dims_, d_input,
reinterpret_cast<cudaStream_t>(cuda_stream_));
LaunchSvd(m_, n_, d_input, output_s, output_u, output_v, d_output_u, d_output_v, dev_info, d_transpose_input_shape,
d_transpose_input_axis);
} else {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(d_input, input, sizeof(T) * total_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
LaunchSvd(n_, m_, d_input, output_s, output_v, output_u, d_output_v, d_output_u, dev_info, d_transpose_input_shape,
d_transpose_input_axis);
LaunchSvd(n_, m_, d_input, output_s, output_v, output_u, d_output_v, d_output_u, dev_info);
}
return true;
}
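The swapped argument order in the else-branch is deliberate. cusolverDn&lt;t&gt;gesvd only supports m >= n, and a row-major m×n buffer reinterpreted as column-major is already Aᵀ, so when m < n the kernel can feed the untransposed copy straight to the n×m SVD and simply exchange the roles of u and v on output:

```latex
A^{\mathsf{T}} = \tilde{U} \,\Sigma\, \tilde{V}^{\mathsf{T}}
\;\Longrightarrow\;
A = \tilde{V} \,\Sigma\, \tilde{U}^{\mathsf{T}},
\qquad\text{hence } U = \tilde{V},\quad V = \tilde{U}.
```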
std::vector<std::pair<KernelAttr, std::pair<SvdGpuKernelMod::LaunchKernelFunc, SvdGpuKernelMod::InitSizeListsFunc>>>
SvdGpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
{&SvdGpuKernelMod::LaunchKernel<float>, &SvdGpuKernelMod::InitSizeLists<float>}},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
{&SvdGpuKernelMod::LaunchKernel<double>, &SvdGpuKernelMod::InitSizeLists<double>}},
std::vector<std::pair<KernelAttr, SvdGpuKernelMod::LaunchKernelFunc>> SvdGpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&SvdGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&SvdGpuKernelMod::LaunchKernel<double>},
};
std::vector<KernelAttr> SvdGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>> &pair) { return pair.first; });
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, LaunchKernelFunc> &pair) { return pair.first; });
return support_list;
}

View File

@@ -31,6 +31,7 @@
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_transpose_impl.cuh"
namespace mindspore {
namespace kernel {
@@ -46,16 +47,11 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = stream_ptr;
return launch_kernel_func_(this, inputs, workspace, outputs);
}
void ResetResource() noexcept {
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
@@ -64,7 +60,6 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override;
protected:
template <typename T>
void InitSizeLists();
template <typename T>
@@ -76,11 +71,10 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
template <typename T>
void RunSvdBatched(const size_t m, const size_t n, T *d_input, T *output_s, T *output_u, T *output_v, int *dev_info);
template <typename T>
void TransposeUV(const size_t m, const size_t n, size_t *d_transpose_input_shape, size_t *d_transpose_input_axis,
T *d_output_u, T *d_output_v, T *output_u, T *output_v);
void TransposeUV(const size_t m, const size_t n, T *d_output_u, T *d_output_v, T *output_u, T *output_v);
template <typename T>
void LaunchSvd(const size_t m, const size_t n, T *d_input, T *output_s, T *output_u, T *output_v, T *d_output_u,
T *d_output_v, int *dev_info, size_t *d_transpose_input_shape, size_t *d_transpose_input_axis);
T *d_output_v, int *dev_info);
void CheckResult(int *dev_info);
using LaunchKernelFunc =
@@ -88,9 +82,9 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
using InitSizeListsFunc = std::function<void(SvdGpuKernelMod *)>;
LaunchKernelFunc launch_kernel_func_{nullptr};
InitSizeListsFunc init_size_lists_func_{nullptr};
static std::vector<std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>>> func_list_;
static std::vector<std::pair<KernelAttr, LaunchKernelFunc>> func_list_;
size_t unit_size_{1};
bool compute_uv_{false};
bool full_matrices_{false};
std::vector<size_t> input_shape_;
@@ -104,9 +98,6 @@ class SvdGpuKernelMod : public NativeGpuKernelMod {
bool m_ge_n_{false};
bool batched_{false};
size_t transpose_input_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
size_t transpose_input_axis_[TRANSPOSE_MAX_DIMENSION] = {0};
bool is_null_input_;
cusolverDnHandle_t handle_{nullptr};
void *cuda_stream_{nullptr};
};

View File

@@ -5003,7 +5003,7 @@ class Tensor(Tensor_):
- **u** (Tensor) - Left singular vectors. If compute_uv is False, u will not be returned.
The shape is :math:`(*, M, P)`. If full_matrices is True, the shape will be :math:`(*, M, M)`.
- **v** (Tensor) - Right singular vectors. If compute_uv is False, v will not be returned.
The shape is :math:`(*, P, N)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.
The shape is :math:`(*, N, P)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.
Raises:
TypeError: If full_matrices or compute_uv is not a bool.

View File

@@ -42,7 +42,7 @@ def svd(a, full_matrices=False, compute_uv=True):
- **u** (Tensor) - Left singular vectors. If `compute_uv` is False, u will not be returned.
The shape is :math:`(*, M, P)`. If `full_matrices` is True, the shape will be :math:`(*, M, M)`.
- **v** (Tensor) - Right singular vectors. If `compute_uv` is False, v will not be returned.
The shape is :math:`(*, P, N)`. If `full_matrices` is True, the shape will be :math:`(*, N, N)`.
The shape is :math:`(*, N, P)`. If `full_matrices` is True, the shape will be :math:`(*, N, N)`.
Raises:
TypeError: If `full_matrices` or `compute_uv` is not a bool.

View File

@@ -83,7 +83,7 @@ class Svd(Primitive):
- **u** (Tensor) - Left singular vectors. If compute_uv is False, u will be an empty tensor.
The shape is :math:`(*, M, P)`. If full_matrices is True, the shape will be :math:`(*, M, M)`.
- **v** (Tensor) - Right singular vectors. If compute_uv is False, v will be an empty tensor.
The shape is :math:`(*, P, N)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.
The shape is :math:`(*, N, P)`. If full_matrices is True, the shape will be :math:`(*, N, N)`.
Raises:
TypeError: If full_matrices or compute_uv is not a bool.