!34928 Add support on CPU platform for op: svd
Merge pull request !34928 from zhuyuxiao/I51VRQ
This commit is contained in:
commit
f8d767d734
|
@ -113,6 +113,12 @@ bool BroadcastToCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
|
|||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
|
||||
CheckArgs();
|
||||
|
||||
if (std::find(input_shape_.begin(), input_shape_.end(), 0) != input_shape_.end() &&
|
||||
std::find(output_shape_.begin(), output_shape_.end(), 0) != output_shape_.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
int status = static_cast<int>(NNACL_OK);
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/svd_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "mindspore/core/ops/svd.h"
|
||||
#include "include/common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const size_t kSvdInputsNum = 1;
|
||||
const size_t kSvdOutputsNum = 3;
|
||||
|
||||
bool SvdCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::Svd>(base_operator);
|
||||
if (!kernel_ptr) {
|
||||
MS_LOG(EXCEPTION) << "cast Svd op failed!";
|
||||
}
|
||||
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
full_matrices_ = kernel_ptr->full_matrices();
|
||||
compute_uv_ = kernel_ptr->compute_uv();
|
||||
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSvdInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSvdOutputsNum, kernel_name_);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SvdCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_func_);
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
int SvdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &onHost) {
|
||||
int ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, onHost);
|
||||
if (ret != 0) {
|
||||
MS_LOG(WARNING) << kernel_name_ << "resize failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> input_shape = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
|
||||
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
|
||||
size_t dim = input_shape.size();
|
||||
if (dim < kDim2) {
|
||||
MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", input dimension must be greater than or equal to 2.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
num_of_rows_ = input_shape[dim - kDim2];
|
||||
num_of_cols_ = input_shape[dim - kDim1];
|
||||
for (size_t i = 0; i < dim - kDim2; i++) {
|
||||
batch_size_ = batch_size_ * input_shape[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool SvdCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto *input_a = reinterpret_cast<T *>(inputs[kIndex0]->addr);
|
||||
auto *output_s = reinterpret_cast<T *>(outputs[kIndex0]->addr);
|
||||
auto *output_u = reinterpret_cast<T *>(outputs[kIndex1]->addr);
|
||||
auto *output_v = reinterpret_cast<T *>(outputs[kIndex2]->addr);
|
||||
|
||||
std::map<bool, std::pair<int, int>> optionMap{{true, {Eigen::ComputeFullU, Eigen::ComputeFullV}},
|
||||
{false, {Eigen::ComputeThinU, Eigen::ComputeThinV}}};
|
||||
std::function<void(std::size_t, std::size_t)> task;
|
||||
|
||||
if (compute_uv_) {
|
||||
task = [&](int64_t start, int64_t end) {
|
||||
for (int64_t start = 0; start < end; ++start) {
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>> matrix(
|
||||
input_a + start * num_of_rows_ * num_of_cols_, num_of_rows_, num_of_cols_);
|
||||
Eigen::BDCSVD<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>> svd(
|
||||
matrix, optionMap[full_matrices_].first | optionMap[full_matrices_].second);
|
||||
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor> s = svd.singularValues();
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor> u = svd.matrixU();
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor> v = svd.matrixV();
|
||||
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>>(output_s + start * s.rows() * s.cols(),
|
||||
s.rows(), s.cols()) = s;
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>>(output_u + start * u.rows() * u.cols(),
|
||||
u.rows(), u.cols()) = u;
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>>(output_v + start * v.rows() * v.cols(),
|
||||
v.rows(), v.cols()) = v;
|
||||
}
|
||||
};
|
||||
} else {
|
||||
task = [&](int64_t start, int64_t end) {
|
||||
for (int64_t start = 0; start < end; ++start) {
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>> matrix(
|
||||
input_a + start * num_of_rows_ * num_of_cols_, num_of_rows_, num_of_cols_);
|
||||
Eigen::BDCSVD<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>> svd(
|
||||
matrix, optionMap[full_matrices_].first | optionMap[full_matrices_].second);
|
||||
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor> s = svd.singularValues();
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, RowMajor>>(output_s + start * s.rows() * s.cols(),
|
||||
s.rows(), s.cols()) = s;
|
||||
}
|
||||
};
|
||||
}
|
||||
ParallelLaunchAutoSearch(task, batch_size_, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> SvdCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, SvdCpuKernelMod::SvdFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, SvdCpuKernelMod::SvdFunc>> SvdCpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&SvdCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&SvdCpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Svd, SvdCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SVD_CPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SVD_CPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SvdCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
SvdCpuKernelMod() {}
|
||||
~SvdCpuKernelMod() = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
using SvdFunc = std::function<bool(SvdCpuKernelMod *, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
|
||||
|
||||
private:
|
||||
static std::vector<std::pair<KernelAttr, SvdFunc>> func_list_;
|
||||
SvdFunc kernel_func_;
|
||||
bool full_matrices_{false};
|
||||
bool compute_uv_{true};
|
||||
int64_t batch_size_ = 1;
|
||||
int64_t num_of_rows_;
|
||||
int64_t num_of_cols_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SVD_CPU_KERNEL_H
|
|
@ -3342,7 +3342,12 @@ class Tensor(Tensor_):
|
|||
[[-0.6386359 0.7695091]
|
||||
[-0.7695091 -0.6386359]]
|
||||
"""
|
||||
return tensor_operator_registry.get("svd")(full_matrices, compute_uv)(self)
|
||||
svd_op = tensor_operator_registry.get("svd")
|
||||
if compute_uv:
|
||||
return svd_op(full_matrices, compute_uv)(self)
|
||||
|
||||
s, _, _ = svd_op(full_matrices, compute_uv)(self)
|
||||
return s
|
||||
|
||||
def hardshrink(self, lambd=0.5):
|
||||
r"""
|
||||
|
|
|
@ -68,7 +68,14 @@ def svd(a, full_matrices=False, compute_uv=True):
|
|||
[-0.7695091 -0.6386359]]
|
||||
"""
|
||||
svd_ = linalg_ops.Svd(full_matrices=full_matrices, compute_uv=compute_uv)
|
||||
return svd_(a)
|
||||
|
||||
if compute_uv:
|
||||
return svd_(a)
|
||||
|
||||
s, _, _ = svd_(a)
|
||||
return s
|
||||
|
||||
|
||||
|
||||
|
||||
__all__ = ['svd']
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore
|
||||
from mindspore import context, ops, nn, Tensor
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.operations import linalg_ops, array_ops
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
RTOL = 1.e-5
|
||||
ATOL = 1.e-6
|
||||
|
||||
k_0 = Tensor(0, mindspore.int32)
|
||||
matmul = ops.MatMul()
|
||||
batch_matmul = ops.BatchMatMul()
|
||||
transpose = ops.Transpose()
|
||||
|
||||
|
||||
@constexpr
|
||||
def make_zero_matrix(shape, dtype):
|
||||
return Tensor(np.zeros(shape), dtype)
|
||||
|
||||
|
||||
def matrix_diag(diagonal, shape):
|
||||
assist_matrix = make_zero_matrix(shape, ops.DType()(diagonal))
|
||||
return array_ops.MatrixSetDiagV3()(assist_matrix, diagonal, k_0)
|
||||
|
||||
|
||||
class SvdNet(nn.Cell):
|
||||
def __init__(self, full_matrices=False, compute_uv=True):
|
||||
super(SvdNet, self).__init__()
|
||||
self.svd = linalg_ops.Svd(full_matrices=full_matrices, compute_uv=compute_uv)
|
||||
|
||||
def construct(self, a):
|
||||
return self.svd(a)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_net1():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: m >= n and full_matrices=False, compute_uv=False
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.random.rand(3, 2)
|
||||
tensor_a = Tensor(a, dtype=mindspore.float32)
|
||||
mscp_svd_net = SvdNet(False, False)
|
||||
s, _, _ = mscp_svd_net(tensor_a)
|
||||
n_s = np.linalg.svd(a, full_matrices=False, compute_uv=False)
|
||||
assert np.allclose(n_s, s.asnumpy(), rtol=RTOL, atol=ATOL)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_net2():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: m >= n and full_matrices=True, compute_uv=True
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.random.rand(3, 2)
|
||||
tensor_a = Tensor(a, dtype=mindspore.float64)
|
||||
mscp_svd_net = SvdNet(True, True)
|
||||
s, u, v = mscp_svd_net(tensor_a)
|
||||
|
||||
output = matmul(u, matmul(matrix_diag(s, (3, 2)), transpose(v, (1, 0))))
|
||||
assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_net3():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: m >= n and full_matrices=False, compute_uv=True
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.random.rand(3, 2)
|
||||
tensor_a = Tensor(a, dtype=mindspore.float32)
|
||||
s, u, v = ops.svd(tensor_a, False, True)
|
||||
output = matmul(u, matmul(matrix_diag(s, (2, 2)), transpose(v, (1, 0))))
|
||||
assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_net4():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: m < n and full_matrices=True, compute_uv=True
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.random.rand(2, 3)
|
||||
tensor_a = Tensor(a, dtype=mindspore.float64)
|
||||
s, u, v = ops.svd(tensor_a, True, True)
|
||||
output = matmul(u, matmul(matrix_diag(s, (2, 3)), transpose(v, (1, 0))))
|
||||
assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_net5():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: inputs shape is (a, b, m, n), m > n
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.random.rand(5, 5, 3, 2)
|
||||
tensor_a = Tensor(a, dtype=mindspore.float32)
|
||||
s, u, v = ops.svd(tensor_a, True, True)
|
||||
|
||||
output = batch_matmul(u, batch_matmul(matrix_diag(s, (5, 5, 3, 2)), transpose(v, (0, 1, 3, 2))))
|
||||
assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_net6():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: specific input 3*2
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.array([[1, 2], [-4, -5], [2, 1]])
|
||||
tensor_a = Tensor(a, dtype=mindspore.float32)
|
||||
s, u, v = linalg_ops.Svd(full_matrices=True, compute_uv=True)(tensor_a)
|
||||
output = matmul(u, matmul(matrix_diag(s, (3, 2)), transpose(v, (1, 0))))
|
||||
assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_vmap1():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: vmap
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.random.rand(5, 3, 3)
|
||||
tensor_a = Tensor(a, dtype=mindspore.float32)
|
||||
net = SvdNet(True, True)
|
||||
svd_vmap = ops.vmap(net, (0,), 0)
|
||||
s, u, v = svd_vmap(tensor_a)
|
||||
output = batch_matmul(u, batch_matmul(matrix_diag(s, (5, 3, 3)), transpose(v, (0, 2, 1))))
|
||||
assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_svd_vmap2():
|
||||
"""
|
||||
Feature: Svd
|
||||
Description: test cases for svd: vmap
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
a = np.random.rand(5, 3, 3)
|
||||
tensor_a = Tensor(a, dtype=mindspore.float32)
|
||||
net = SvdNet(True, False)
|
||||
svd_vmap = ops.vmap(net, (0,), 0)
|
||||
s, _, _ = svd_vmap(tensor_a)
|
||||
n_s = np.linalg.svd(a, full_matrices=True, compute_uv=False)
|
||||
assert np.allclose(n_s, s.asnumpy(), rtol=RTOL, atol=ATOL)
|
Loading…
Reference in New Issue