@@ -31,8 +31,7 @@ bool SvdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vect
     MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
     return false;
   }
-  launch_kernel_func_ = func_list_[index].second.first;
-  init_size_lists_func_ = func_list_[index].second.second;
+  launch_kernel_func_ = func_list_[index].second;
   compute_uv_ = kernel_ptr->compute_uv();
   full_matrices_ = kernel_ptr->full_matrices();
   job_ = compute_uv_ ? (full_matrices_ ? 'A' : 'S') : 'N';
@@ -47,21 +46,11 @@ int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
     return ret;
   }
   auto input_shape = inputs[kIndex0]->GetShapeVector();
-  if (IsDynamicRank(input_shape)) {
-    return KRET_OK;
-  }
-  DestroyResource();
-  ResetResource();
   input_shape_ = Convert2SizeTClipNeg(input_shape);
   total_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies<size_t>());
-  is_null_input_ = (total_size_ == 0);
-  if (is_null_input_) {
-    init_size_lists_func_(this);
-    return KRET_OK;
-  }
   dims_ = input_shape_.size();
   if (dims_ < kDim2) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got [" << dims_;
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got " << dims_;
   }

   m_ = input_shape_[dims_ - kDim2];
@@ -74,57 +63,46 @@ int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
   for (size_t i = 0; i < dims_ - kDim2; i++) {
     batch_size_ = batch_size_ * input_shape_[i];
   }
-  for (size_t i = 0; i < dims_; ++i) {
-    transpose_input_shape_[i] = input_shape_[i];
-    if (i == dims_ - kDim2) {
-      transpose_input_axis_[i] = dims_ - kDim1;
-    } else if (i == dims_ - kDim1) {
-      transpose_input_axis_[i] = dims_ - kDim2;
-    } else {
-      transpose_input_axis_[i] = i;
-    }
-  }
-  init_size_lists_func_(this);
   constexpr auto kBatchedMaxRowCol = 32;
   if (m_ <= kBatchedMaxRowCol && n_ <= kBatchedMaxRowCol && batch_size_ > 1 && (full_matrices_ || m_ == n_)) {
     batched_ = true;
   }
+  unit_size_ = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
+  ResetResource();
+  InitSizeLists();
   return 0;
 }

-template <typename T>
 void SvdGpuKernelMod::InitSizeLists() {
   // input a
-  input_size_list_.push_back(total_size_ * sizeof(T));
+  input_size_list_.push_back(total_size_ * unit_size_);
   // output s, u, v
-  output_size_list_.push_back(batch_size_ * p_ * sizeof(T));
+  output_size_list_.push_back(batch_size_ * p_ * unit_size_);
   if (compute_uv_) {
     if (full_matrices_) {
-      output_size_list_.push_back(batch_size_ * m_ * m_ * sizeof(T));
-      output_size_list_.push_back(batch_size_ * n_ * n_ * sizeof(T));
+      output_size_list_.push_back(batch_size_ * m_ * m_ * unit_size_);
+      output_size_list_.push_back(batch_size_ * n_ * n_ * unit_size_);
     } else {
-      output_size_list_.push_back(batch_size_ * m_ * p_ * sizeof(T));
-      output_size_list_.push_back(batch_size_ * n_ * p_ * sizeof(T));
+      output_size_list_.push_back(batch_size_ * m_ * p_ * unit_size_);
+      output_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
     }
   } else {
     output_size_list_.push_back(0);
     output_size_list_.push_back(0);
   }
-  workspace_size_list_.push_back(batch_size_ * sizeof(int));  // for dev_info
-
+  // for dev_info
+  workspace_size_list_.push_back(batch_size_ * sizeof(int));
   // for transpose input
-  workspace_size_list_.push_back(dims_ * sizeof(size_t));
-  workspace_size_list_.push_back(dims_ * sizeof(size_t));
-  workspace_size_list_.push_back(total_size_ * sizeof(T));
+  workspace_size_list_.push_back(total_size_ * unit_size_);

   // for dev_u and dev_v
-  if (compute_uv_) {
-    if (full_matrices_) {
-      workspace_size_list_.push_back(batch_size_ * m_ * m_ * sizeof(T));
-      workspace_size_list_.push_back(batch_size_ * n_ * n_ * sizeof(T));
+  if (compute_uv_ || batched_) {
+    if (full_matrices_ || batched_) {
+      workspace_size_list_.push_back(batch_size_ * m_ * m_ * unit_size_);
+      workspace_size_list_.push_back(batch_size_ * n_ * n_ * unit_size_);
     } else {
-      workspace_size_list_.push_back(batch_size_ * m_ * p_ * sizeof(T));
-      workspace_size_list_.push_back(batch_size_ * n_ * p_ * sizeof(T));
+      workspace_size_list_.push_back(batch_size_ * m_ * p_ * unit_size_);
+      workspace_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
     }
   }
 }
@@ -133,7 +111,6 @@ template <typename T>
 void SvdGpuKernelMod::RunSvd(const size_t m, const size_t n, const size_t batch, T *d_a, int *dev_info, T *output_s,
                              T *d_output_u, T *d_output_v) {
   int lwork = 0;
-
   if constexpr (std::is_same_v<T, float>) {
     CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnSgesvd_bufferSize(handle_, m, n, &lwork),
                                            "cusolver query svd work size fail");
@@ -208,59 +185,35 @@ void SvdGpuKernelMod::RunSvdBatched(const size_t m, const size_t n, T *d_input,
                                            "cusolver svd fail");
   }

   CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnDestroyGesvdjInfo(info), "cusolver svd fail");
   device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(work);
 }

 template <typename T>
-void SvdGpuKernelMod::TransposeUV(const size_t m, const size_t n, size_t *d_transpose_input_shape,
-                                  size_t *d_transpose_input_axis, T *d_output_u, T *d_output_v, T *output_u,
+void SvdGpuKernelMod::TransposeUV(const size_t m, const size_t n, T *d_output_u, T *d_output_v, T *output_u,
                                   T *output_v) {
   if (full_matrices_) {
-    transpose_input_shape_[dims_ - kDim2] = m;
-    transpose_input_shape_[dims_ - kDim1] = m;
-    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-      cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                      reinterpret_cast<cudaStream_t>(cuda_stream_)),
-      "cuda memcpy failed!");
-    CalTranspose(batch_size_ * m * m, d_output_u, d_transpose_input_shape, d_transpose_input_axis, dims_, output_u,
-                 reinterpret_cast<cudaStream_t>(cuda_stream_));
+    MatrixTranspose(d_output_u, SizeToInt(batch_size_ * m * m), SizeToInt(m), SizeToInt(m), output_u, device_id_,
+                    reinterpret_cast<cudaStream_t>(cuda_stream_));
     if (batched_) {
-      transpose_input_shape_[dims_ - kDim2] = n;
-      transpose_input_shape_[dims_ - kDim1] = n;
-      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                        reinterpret_cast<cudaStream_t>(cuda_stream_)),
-        "cuda memcpy failed!");
-      CalTranspose(batch_size_ * n * n, d_output_v, d_transpose_input_shape, d_transpose_input_axis, dims_, output_v,
-                   reinterpret_cast<cudaStream_t>(cuda_stream_));
+      MatrixTranspose(d_output_v, SizeToInt(batch_size_ * n * n), SizeToInt(n), SizeToInt(n), output_v, device_id_,
+                      reinterpret_cast<cudaStream_t>(cuda_stream_));
     } else {
       CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(output_v, d_output_v, sizeof(size_t) * batch_size_ * n * n, cudaMemcpyHostToDevice,
+        cudaMemcpyAsync(output_v, d_output_v, sizeof(T) * batch_size_ * n * n, cudaMemcpyDeviceToDevice,
                         reinterpret_cast<cudaStream_t>(cuda_stream_)),
         "cuda memcpy failed!");
     }
   } else {
-    transpose_input_shape_[dims_ - kDim2] = p_;
-    transpose_input_shape_[dims_ - kDim1] = m;
-    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-      cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                      reinterpret_cast<cudaStream_t>(cuda_stream_)),
-      "cuda memcpy failed!");
-    CalTranspose(batch_size_ * m * p_, d_output_u, d_transpose_input_shape, d_transpose_input_axis, dims_, output_u,
-                 reinterpret_cast<cudaStream_t>(cuda_stream_));
+    MatrixTranspose(d_output_u, SizeToInt(batch_size_ * m * p_), SizeToInt(p_), SizeToInt(m), output_u, device_id_,
+                    reinterpret_cast<cudaStream_t>(cuda_stream_));
+
     if (batched_) {
-      transpose_input_shape_[dims_ - kDim2] = p_;
-      transpose_input_shape_[dims_ - kDim1] = n;
-      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                        reinterpret_cast<cudaStream_t>(cuda_stream_)),
-        "cuda memcpy failed!");
-      CalTranspose(batch_size_ * n * p_, d_output_v, d_transpose_input_shape, d_transpose_input_axis, dims_, output_v,
-                   reinterpret_cast<cudaStream_t>(cuda_stream_));
+      MatrixTranspose(d_output_v, SizeToInt(batch_size_ * n * p_), SizeToInt(p_), SizeToInt(n), output_v, device_id_,
+                      reinterpret_cast<cudaStream_t>(cuda_stream_));
     } else {
       CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-        cudaMemcpyAsync(output_v, d_output_v, sizeof(size_t) * batch_size_ * n * p_, cudaMemcpyHostToDevice,
+        cudaMemcpyAsync(output_v, d_output_v, sizeof(T) * batch_size_ * n * p_, cudaMemcpyDeviceToDevice,
                         reinterpret_cast<cudaStream_t>(cuda_stream_)),
         "cuda memcpy failed!");
     }
@@ -286,8 +239,7 @@ void SvdGpuKernelMod::CheckResult(int *dev_info) {

 template <typename T>
 void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *output_s, T *output_u, T *output_v,
-                                T *d_output_u, T *d_output_v, int *dev_info, size_t *d_transpose_input_shape,
-                                size_t *d_transpose_input_axis) {
+                                T *d_output_u, T *d_output_v, int *dev_info) {
   if (batched_) {
     RunSvdBatched(m, n, d_input, output_s, d_output_u, d_output_v, dev_info);
   } else {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (compute_uv_) {
|
|
|
|
|
TransposeUV(m, n, d_transpose_input_shape, d_transpose_input_axis, d_output_u, d_output_v, output_u, output_v);
|
|
|
|
|
TransposeUV(m, n, d_output_u, d_output_v, output_u, output_v);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CheckResult(dev_info);
|
|
|
|
@@ -306,9 +258,6 @@ void SvdGpuKernelMod::LaunchSvd(const size_t m, const size_t n, T *d_input, T *o
 template <typename T>
 bool SvdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs) {
-  if (is_null_input_) {
-    return true;
-  }
   CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
                                 "CusolverDnSetStream failed");

@@ -318,67 +267,55 @@ bool SvdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
   T *output_v = nullptr;
   T *d_output_u = nullptr;
   T *d_output_v = nullptr;
-  if (compute_uv_) {
-    output_u = GetDeviceAddress<T>(outputs, kIndex1);
-    output_v = GetDeviceAddress<T>(outputs, kIndex2);
+  if (compute_uv_ || batched_) {
+    if (compute_uv_) {
+      output_u = GetDeviceAddress<T>(outputs, kIndex1);
+      output_v = GetDeviceAddress<T>(outputs, kIndex2);
+    }
     // Store output u and v before transpose.
-    d_output_u = GetDeviceAddress<T>(workspace, kIndex4);
-    d_output_v = GetDeviceAddress<T>(workspace, kIndex5);
+    d_output_u = GetDeviceAddress<T>(workspace, kIndex2);
+    d_output_v = GetDeviceAddress<T>(workspace, kIndex3);
   }

   int *dev_info = GetDeviceAddress<int>(workspace, kIndex0);
-
-  size_t *d_transpose_input_shape = GetDeviceAddress<size_t>(workspace, kIndex1);
-  size_t *d_transpose_input_axis = GetDeviceAddress<size_t>(workspace, kIndex2);
-  T *d_input = GetDeviceAddress<T>(workspace, kIndex3);
-
-  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-    cudaMemcpyAsync(d_transpose_input_axis, transpose_input_axis_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
-                    reinterpret_cast<cudaStream_t>(cuda_stream_)),
-    "cuda memcpy failed!");
+  T *d_input = GetDeviceAddress<T>(workspace, kIndex1);

   if (m_ge_n_) {
     // Because cudaSovler expects column-major matrix, we need transpose A.
+    MatrixTranspose(input, SizeToInt(total_size_), SizeToInt(m_), SizeToInt(n_), d_input, device_id_,
+                    reinterpret_cast<cudaStream_t>(cuda_stream_));
+    LaunchSvd(m_, n_, d_input, output_s, output_u, output_v, d_output_u, d_output_v, dev_info);
+  } else {
     CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-      cudaMemcpyAsync(d_transpose_input_shape, transpose_input_shape_, sizeof(size_t) * dims_, cudaMemcpyHostToDevice,
+      cudaMemcpyAsync(d_input, input, sizeof(T) * total_size_, cudaMemcpyDeviceToDevice,
                       reinterpret_cast<cudaStream_t>(cuda_stream_)),
       "cuda memcpy failed!");
-    CalTranspose(total_size_, input, d_transpose_input_shape, d_transpose_input_axis, dims_, d_input,
-                 reinterpret_cast<cudaStream_t>(cuda_stream_));
-    LaunchSvd(m_, n_, d_input, output_s, output_u, output_v, d_output_u, d_output_v, dev_info, d_transpose_input_shape,
-              d_transpose_input_axis);
-  } else {
-    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(d_input, input, sizeof(T) * total_size_, cudaMemcpyHostToDevice,
-                                                       reinterpret_cast<cudaStream_t>(cuda_stream_)),
-                                       "cuda memcpy failed!");
-    LaunchSvd(n_, m_, d_input, output_s, output_v, output_u, d_output_v, d_output_u, dev_info, d_transpose_input_shape,
-              d_transpose_input_axis);
+    LaunchSvd(n_, m_, d_input, output_s, output_v, output_u, d_output_v, d_output_u, dev_info);
   }

   return true;
 }

-std::vector<std::pair<KernelAttr, std::pair<SvdGpuKernelMod::LaunchKernelFunc, SvdGpuKernelMod::InitSizeListsFunc>>>
-  SvdGpuKernelMod::func_list_ = {
-    {KernelAttr()
-       .AddInputAttr(kNumberTypeFloat32)
-       .AddOutputAttr(kNumberTypeFloat32)
-       .AddOutputAttr(kNumberTypeFloat32)
-       .AddOutputAttr(kNumberTypeFloat32),
-     {&SvdGpuKernelMod::LaunchKernel<float>, &SvdGpuKernelMod::InitSizeLists<float>}},
-    {KernelAttr()
-       .AddInputAttr(kNumberTypeFloat64)
-       .AddOutputAttr(kNumberTypeFloat64)
-       .AddOutputAttr(kNumberTypeFloat64)
-       .AddOutputAttr(kNumberTypeFloat64),
-     {&SvdGpuKernelMod::LaunchKernel<double>, &SvdGpuKernelMod::InitSizeLists<double>}},
+std::vector<std::pair<KernelAttr, SvdGpuKernelMod::LaunchKernelFunc>> SvdGpuKernelMod::func_list_ = {
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &SvdGpuKernelMod::LaunchKernel<float>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64),
+   &SvdGpuKernelMod::LaunchKernel<double>},
 };

 std::vector<KernelAttr> SvdGpuKernelMod::GetOpSupport() {
   std::vector<KernelAttr> support_list;
-  (void)std::transform(
-    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
-    [](const std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>> &pair) { return pair.first; });
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, LaunchKernelFunc> &pair) { return pair.first; });
   return support_list;
 }
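Note (not part of the patch): the kernel's transpose/swap logic exists because cuSOLVER's dense gesvd works on column-major data and only accepts m >= n, which is why LaunchKernel transposes the row-major input when m_ge_n_ and otherwise swaps the roles of m/n and U/V; the kBatchedMaxRowCol = 32 guard in Resize appears to correspond to the documented size limit of the batched gesvdj path. The sketch below calls cusolverDnSgesvd directly on a small column-major matrix to show that convention in isolation. It is a minimal standalone example, all names are local to it, and error handling is abbreviated.

// svd_gesvd_sketch.cu  (build with: nvcc svd_gesvd_sketch.cu -lcusolver)
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <cstdio>
#include <vector>

int main() {
  // A is 3x2, stored column-major as cuSOLVER expects; gesvd additionally requires m >= n.
  const int m = 3, n = 2, lda = m;
  std::vector<float> a_host = {1.f, 2.f, 3.f,   // column 0
                               4.f, 5.f, 6.f};  // column 1

  cusolverDnHandle_t handle = nullptr;
  cusolverDnCreate(&handle);

  float *d_a = nullptr, *d_s = nullptr, *d_u = nullptr, *d_vt = nullptr, *d_work = nullptr;
  int *dev_info = nullptr;
  cudaMalloc(&d_a, sizeof(float) * m * n);
  cudaMalloc(&d_s, sizeof(float) * n);       // min(m, n) singular values
  cudaMalloc(&d_u, sizeof(float) * m * m);   // full U, as with job_ = 'A'
  cudaMalloc(&d_vt, sizeof(float) * n * n);  // full V^T
  cudaMalloc(&dev_info, sizeof(int));
  cudaMemcpy(d_a, a_host.data(), sizeof(float) * m * n, cudaMemcpyHostToDevice);

  // Query and allocate the work buffer, mirroring the _bufferSize call in RunSvd.
  int lwork = 0;
  cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
  cudaMalloc(&d_work, sizeof(float) * lwork);

  // 'A' requests full matrices, the analogue of full_matrices_ = true in the kernel.
  cusolverDnSgesvd(handle, 'A', 'A', m, n, d_a, lda, d_s, d_u, m, d_vt, n, d_work, lwork,
                   nullptr, dev_info);

  int info = 0;
  cudaMemcpy(&info, dev_info, sizeof(int), cudaMemcpyDeviceToHost);
  std::vector<float> s_host(n);
  cudaMemcpy(s_host.data(), d_s, sizeof(float) * n, cudaMemcpyDeviceToHost);
  printf("info=%d, singular values: %f %f\n", info, s_host[0], s_host[1]);

  cudaFree(d_a); cudaFree(d_s); cudaFree(d_u); cudaFree(d_vt); cudaFree(d_work); cudaFree(dev_info);
  cusolverDnDestroy(handle);
  return 0;
}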