diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/broadcast_to_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/broadcast_to_cpu_kernel.cc index aa6447d709d..0aa7656b45d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/broadcast_to_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/broadcast_to_cpu_kernel.cc @@ -113,6 +113,12 @@ bool BroadcastToCpuKernelMod::LaunchKernel(const std::vector &inputs const std::vector &outputs) { CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_); CheckArgs(); + + if (std::find(input_shape_.begin(), input_shape_.end(), 0) != input_shape_.end() && + std::find(output_shape_.begin(), output_shape_.end(), 0) != output_shape_.end()) { + return true; + } + const auto *input_addr = reinterpret_cast(inputs[0]->addr); auto *output_addr = reinterpret_cast(outputs[0]->addr); int status = static_cast(NNACL_OK); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/svd_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/svd_cpu_kernel.cc new file mode 100644 index 00000000000..817011a3906 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/svd_cpu_kernel.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/svd_cpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "mindspore/core/ops/svd.h" +#include "include/common/thread_pool.h" + +namespace mindspore { +namespace kernel { +const size_t kSvdInputsNum = 1; +const size_t kSvdOutputsNum = 3; + +bool SvdCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + auto kernel_ptr = std::dynamic_pointer_cast(base_operator); + if (!kernel_ptr) { + MS_LOG(EXCEPTION) << "cast Svd op failed!"; + } + + kernel_name_ = kernel_ptr->name(); + full_matrices_ = kernel_ptr->full_matrices(); + compute_uv_ = kernel_ptr->compute_uv(); + + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSvdInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSvdOutputsNum, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this data type: " << kernel_attr; + return false; + } + + kernel_func_ = func_list_[index].second; + return true; +} + +bool SvdCpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(kernel_func_); + return kernel_func_(this, inputs, workspace, outputs); +} + +int SvdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &onHost) { + int ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, onHost); + if (ret != 0) { + MS_LOG(WARNING) << kernel_name_ << "resize failed."; + return ret; + } + + std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(), + inputs.at(kIndex0)->GetDeviceShapeAdaptively().end()); + size_t dim = input_shape.size(); + if (dim < kDim2) { + MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", input dimension must be greater than or equal to 2."; + return -1; + } + + num_of_rows_ = input_shape[dim - kDim2]; + num_of_cols_ = input_shape[dim - kDim1]; + for (size_t i = 0; i < dim - kDim2; i++) { + batch_size_ = batch_size_ * input_shape[i]; + } + return ret; +} + +template +bool SvdCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + auto *input_a = reinterpret_cast(inputs[kIndex0]->addr); + auto *output_s = reinterpret_cast(outputs[kIndex0]->addr); + auto *output_u = reinterpret_cast(outputs[kIndex1]->addr); + auto *output_v = reinterpret_cast(outputs[kIndex2]->addr); + + std::map> optionMap{{true, {Eigen::ComputeFullU, Eigen::ComputeFullV}}, + {false, {Eigen::ComputeThinU, Eigen::ComputeThinV}}}; + std::function task; + + if (compute_uv_) { + task = [&](int64_t start, int64_t end) { + for (int64_t start = 0; start < end; ++start) { + Eigen::Map> matrix( + input_a + start * num_of_rows_ * num_of_cols_, num_of_rows_, num_of_cols_); + Eigen::BDCSVD> svd( + matrix, optionMap[full_matrices_].first | optionMap[full_matrices_].second); + + Eigen::Matrix s = svd.singularValues(); + Eigen::Matrix u = svd.matrixU(); + Eigen::Matrix v = svd.matrixV(); + + Eigen::Map>(output_s + start * s.rows() * s.cols(), + s.rows(), s.cols()) = s; + Eigen::Map>(output_u + start * u.rows() * u.cols(), + u.rows(), u.cols()) = u; + Eigen::Map>(output_v + start * v.rows() * v.cols(), + v.rows(), v.cols()) = v; + } + }; + } else { + task = [&](int64_t start, int64_t end) { + for (int64_t start = 0; start < end; ++start) { + Eigen::Map> matrix( + input_a + start * num_of_rows_ * num_of_cols_, num_of_rows_, num_of_cols_); + Eigen::BDCSVD> svd( + matrix, optionMap[full_matrices_].first | optionMap[full_matrices_].second); + + Eigen::Matrix s = svd.singularValues(); + Eigen::Map>(output_s + start * s.rows() * s.cols(), + s.rows(), s.cols()) = s; + } + }; + } + ParallelLaunchAutoSearch(task, batch_size_, this, ¶llel_search_info_); + return true; +} + +std::vector SvdCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +std::vector> SvdCpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &SvdCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &SvdCpuKernelMod::LaunchKernel}}; + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Svd, SvdCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/svd_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/svd_cpu_kernel.h new file mode 100644 index 00000000000..3c18ee336af --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/svd_cpu_kernel.h @@ -0,0 +1,69 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SVD_CPU_KERNEL_H +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SVD_CPU_KERNEL_H + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SvdCpuKernelMod : public NativeCpuKernelMod { + public: + SvdCpuKernelMod() {} + ~SvdCpuKernelMod() = default; + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + using SvdFunc = std::function &, + const std::vector &, const std::vector &)>; + + private: + static std::vector> func_list_; + SvdFunc kernel_func_; + bool full_matrices_{false}; + bool compute_uv_{true}; + int64_t batch_size_ = 1; + int64_t num_of_rows_; + int64_t num_of_cols_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SVD_CPU_KERNEL_H diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 546f99539e2..d6d73fd0eeb 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -3342,7 +3342,12 @@ class Tensor(Tensor_): [[-0.6386359 0.7695091] [-0.7695091 -0.6386359]] """ - return tensor_operator_registry.get("svd")(full_matrices, compute_uv)(self) + svd_op = tensor_operator_registry.get("svd") + if compute_uv: + return svd_op(full_matrices, compute_uv)(self) + + s, _, _ = svd_op(full_matrices, compute_uv)(self) + return s def hardshrink(self, lambd=0.5): r""" diff --git a/mindspore/python/mindspore/ops/function/linalg_func.py b/mindspore/python/mindspore/ops/function/linalg_func.py index 68defab58bd..85700fd6044 100644 --- a/mindspore/python/mindspore/ops/function/linalg_func.py +++ b/mindspore/python/mindspore/ops/function/linalg_func.py @@ -68,7 +68,14 @@ def svd(a, full_matrices=False, compute_uv=True): [-0.7695091 -0.6386359]] """ svd_ = linalg_ops.Svd(full_matrices=full_matrices, compute_uv=compute_uv) - return svd_(a) + + if compute_uv: + return svd_(a) + + s, _, _ = svd_(a) + return s + + __all__ = ['svd'] diff --git a/tests/st/ops/cpu/test_svd_op.py b/tests/st/ops/cpu/test_svd_op.py new file mode 100644 index 00000000000..dce6c2b061c --- /dev/null +++ b/tests/st/ops/cpu/test_svd_op.py @@ -0,0 +1,184 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore +from mindspore import context, ops, nn, Tensor +from mindspore.ops.primitive import constexpr +from mindspore.ops.operations import linalg_ops, array_ops + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') +RTOL = 1.e-5 +ATOL = 1.e-6 + +k_0 = Tensor(0, mindspore.int32) +matmul = ops.MatMul() +batch_matmul = ops.BatchMatMul() +transpose = ops.Transpose() + + +@constexpr +def make_zero_matrix(shape, dtype): + return Tensor(np.zeros(shape), dtype) + + +def matrix_diag(diagonal, shape): + assist_matrix = make_zero_matrix(shape, ops.DType()(diagonal)) + return array_ops.MatrixSetDiagV3()(assist_matrix, diagonal, k_0) + + +class SvdNet(nn.Cell): + def __init__(self, full_matrices=False, compute_uv=True): + super(SvdNet, self).__init__() + self.svd = linalg_ops.Svd(full_matrices=full_matrices, compute_uv=compute_uv) + + def construct(self, a): + return self.svd(a) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_net1(): + """ + Feature: Svd + Description: test cases for svd: m >= n and full_matrices=False, compute_uv=False + Expectation: the result match to numpy + """ + a = np.random.rand(3, 2) + tensor_a = Tensor(a, dtype=mindspore.float32) + mscp_svd_net = SvdNet(False, False) + s, _, _ = mscp_svd_net(tensor_a) + n_s = np.linalg.svd(a, full_matrices=False, compute_uv=False) + assert np.allclose(n_s, s.asnumpy(), rtol=RTOL, atol=ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_net2(): + """ + Feature: Svd + Description: test cases for svd: m >= n and full_matrices=True, compute_uv=True + Expectation: the result match to numpy + """ + a = np.random.rand(3, 2) + tensor_a = Tensor(a, dtype=mindspore.float64) + mscp_svd_net = SvdNet(True, True) + s, u, v = mscp_svd_net(tensor_a) + + output = matmul(u, matmul(matrix_diag(s, (3, 2)), transpose(v, (1, 0)))) + assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_net3(): + """ + Feature: Svd + Description: test cases for svd: m >= n and full_matrices=False, compute_uv=True + Expectation: the result match to numpy + """ + a = np.random.rand(3, 2) + tensor_a = Tensor(a, dtype=mindspore.float32) + s, u, v = ops.svd(tensor_a, False, True) + output = matmul(u, matmul(matrix_diag(s, (2, 2)), transpose(v, (1, 0)))) + assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_net4(): + """ + Feature: Svd + Description: test cases for svd: m < n and full_matrices=True, compute_uv=True + Expectation: the result match to numpy + """ + a = np.random.rand(2, 3) + tensor_a = Tensor(a, dtype=mindspore.float64) + s, u, v = ops.svd(tensor_a, True, True) + output = matmul(u, matmul(matrix_diag(s, (2, 3)), transpose(v, (1, 0)))) + assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_net5(): + """ + Feature: Svd + Description: test cases for svd: inputs shape is (a, b, m, n), m > n + Expectation: the result match to numpy + """ + a = np.random.rand(5, 5, 3, 2) + tensor_a = Tensor(a, dtype=mindspore.float32) + s, u, v = ops.svd(tensor_a, True, True) + + output = batch_matmul(u, batch_matmul(matrix_diag(s, (5, 5, 3, 2)), transpose(v, (0, 1, 3, 2)))) + assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_net6(): + """ + Feature: Svd + Description: test cases for svd: specific input 3*2 + Expectation: the result match to numpy + """ + a = np.array([[1, 2], [-4, -5], [2, 1]]) + tensor_a = Tensor(a, dtype=mindspore.float32) + s, u, v = linalg_ops.Svd(full_matrices=True, compute_uv=True)(tensor_a) + output = matmul(u, matmul(matrix_diag(s, (3, 2)), transpose(v, (1, 0)))) + assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_vmap1(): + """ + Feature: Svd + Description: test cases for svd: vmap + Expectation: the result match to numpy + """ + a = np.random.rand(5, 3, 3) + tensor_a = Tensor(a, dtype=mindspore.float32) + net = SvdNet(True, True) + svd_vmap = ops.vmap(net, (0,), 0) + s, u, v = svd_vmap(tensor_a) + output = batch_matmul(u, batch_matmul(matrix_diag(s, (5, 3, 3)), transpose(v, (0, 2, 1)))) + assert np.allclose(a, output.asnumpy(), rtol=RTOL, atol=ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_svd_vmap2(): + """ + Feature: Svd + Description: test cases for svd: vmap + Expectation: the result match to numpy + """ + a = np.random.rand(5, 3, 3) + tensor_a = Tensor(a, dtype=mindspore.float32) + net = SvdNet(True, False) + svd_vmap = ops.vmap(net, (0,), 0) + s, _, _ = svd_vmap(tensor_a) + n_s = np.linalg.svd(a, full_matrices=True, compute_uv=False) + assert np.allclose(n_s, s.asnumpy(), rtol=RTOL, atol=ATOL)