From 37bb297bc78f1fcc47ca5a65e3b76cba8c53f628 Mon Sep 17 00:00:00 2001 From: zy <1228862915@qq.com> Date: Wed, 19 Jan 2022 11:02:48 +0800 Subject: [PATCH] [feat] [assistant] [I40GI8] Add new operation LuSolve --- .../cpu/lu_solve_cpu_kernel.cc | 177 ++++++++++++++++++ .../kernel_compiler/cpu/lu_solve_cpu_kernel.h | 64 +++++++ mindspore/core/base/core_ops.h | 1 + mindspore/core/ops/lu_solve_.cc | 162 ++++++++++++++++ mindspore/core/ops/lu_solve_.h | 43 +++++ .../mindspore/ops/_op_impl/aicpu/__init__.py | 1 + .../mindspore/ops/_op_impl/aicpu/lu_solve.py | 32 ++++ .../mindspore/ops/operations/__init__.py | 3 +- .../mindspore/ops/operations/math_ops.py | 53 ++++++ tests/ut/python/ops/test_ops.py | 6 + 10 files changed, 541 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.h create mode 100644 mindspore/core/ops/lu_solve_.cc create mode 100644 mindspore/core/ops/lu_solve_.h create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.cc new file mode 100644 index 00000000000..a86a0fbe66f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.cc @@ -0,0 +1,177 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/lu_solve_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+// Number of trailing dimensions forming the matrix part of x / lu_data; the leading ones are batch dimensions.
+constexpr size_t kDimNum = 2;
+
+// Throws ValueError unless the batch dimensions of x and lu_data are equal when aligned from the right.
+// Both shapes must already be known to have at least kDimNum dimensions.
+void CheckBatchDimsMatch(const std::vector<size_t> &x_shape, const std::vector<size_t> &lu_data_shape) {
+  const size_t x_batch = x_shape.size() - kDimNum;
+  const size_t lu_batch = lu_data_shape.size() - kDimNum;
+  const size_t common = x_batch < lu_batch ? x_batch : lu_batch;
+  for (size_t i = 0; i < common; i++) {
+    const size_t x_idx = x_batch - common + i;
+    const size_t lu_idx = lu_batch - common + i;
+    if (x_shape[x_idx] != lu_data_shape[lu_idx]) {
+      MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel, x's dim[" << x_idx << "] is not the same as lu_data's dim["
+                               << lu_idx << "], while x is " << x_shape[x_idx] << ", lu_data is "
+                               << lu_data_shape[lu_idx] << ".";
+    }
+  }
+}
+}  // namespace
+
+// Returns the product of all dimensions in shape (1 for an empty shape).
+size_t get_element_num(const std::vector<size_t> &shape) {
+  size_t size = 1;
+  for (size_t dim : shape) {
+    size *= dim;
+  }
+  return size;
+}
+
+// Caches the node and validates the tensor counts and the x, lu_data and lu_pivots shapes (ValueError if invalid).
+template <typename T1, typename T2>
+void LuSolveCPUKernel<T1, T2>::InitKernel(const CNodePtr &kernel_node) {
+  node_wpt_ = kernel_node;
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
+  CHECK_KERNEL_INPUTS_NUM(input_num, kInputNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputNum, kernel_name_);
+  auto x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  auto lu_data_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+  auto lu_pivots_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
+  if (lu_data_shape.size() < kDimNum) {
+    MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel lu_data's dimensions should be greater than or equal to 2.";
+  }
+  if (x_shape.size() < kDimNum) {
+    MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel x's dimensions should be greater than or equal to 2.";
+  }
+  if (lu_pivots_shape.size() < 1) {
+    MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel lu_pivots's dimensions should be greater than or equal to 1.";
+  }
+  const size_t lu_rows = lu_data_shape[lu_data_shape.size() - kDimNum];
+  const size_t lu_cols = lu_data_shape[lu_data_shape.size() - 1];
+  if (lu_rows != lu_cols) {
+    MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel input lu_data should be square matrix while row is "
+                             << lu_rows << ", col is " << lu_cols << ".";
+  }
+  // The right-hand side needs one row per row of the factorized matrix, whatever the two ranks are.
+  if (x_shape[x_shape.size() - kDimNum] != lu_rows) {
+    MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel x should have as many rows as lu_data while x has "
+                             << x_shape[x_shape.size() - kDimNum] << ", lu_data has " << lu_rows << ".";
+  }
+  CheckBatchDimsMatch(x_shape, lu_data_shape);
+  // lu_pivots is (*, m): exactly one dimension fewer than lu_data. Checking the rank first keeps the loop below
+  // inside lu_data_shape and keeps Launch from reading past the end of the pivot buffer.
+  if (lu_pivots_shape.size() + 1 != lu_data_shape.size()) {
+    MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel lu_pivots should have one dimension fewer than lu_data.";
+  }
+  if (lu_pivots_shape[lu_pivots_shape.size() - 1] != lu_cols) {
+    MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel "
+                             << "Number of pivots per batch should be same as the dimension of the matrix.";
+  }
+  for (size_t i = 0; i + 1 < lu_pivots_shape.size(); i++) {
+    if (lu_data_shape[i] != lu_pivots_shape[i]) {
+      MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel "
+                               << "batch dimension of LU_pivots should match batch dimension of LU_data.";
+    }
+  }
+}
+
+// Solves one batch element: applies the pivots' row swaps to b, rebuilds the matrix from its packed L/U factors,
+// solves for x and writes b_stride results at offset a * b_stride of outputs[0]. ValueError on a bad pivot.
+template <typename T1, typename T2>
+void LuSolveCPUKernel<T1, T2>::LuSolve(const std::vector<kernel::AddressPtr> &inputs,
+                                       const std::vector<kernel::AddressPtr> &outputs, T1 *b_working_ptr,
+                                       T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a) {
+  auto input_0_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
+  auto input_1_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
+  auto output_y = reinterpret_cast<T2 *>(outputs[0]->addr);
+  const size_t lu_matrix_size = input_1_Shape[input_1_Shape.size() - kDimNum];
+  const size_t b_cols = input_0_Shape[input_0_Shape.size() - 1];
+  // NOTE(review): Eigen::Matrix arguments reconstructed; row-major assumed to match the tensor layout - confirm.
+  using MatrixXd = Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+  MatrixXd matrix_b = Eigen::Map<MatrixXd>(b_working_ptr, lu_matrix_size, b_cols);
+  MatrixXd matrix_A = Eigen::Map<MatrixXd>(lu_working_ptr, lu_matrix_size, lu_matrix_size);
+  for (size_t i = 0; i < lu_matrix_size; i++) {
+    // Pivots are 1-based row indices; reject anything that would swap with a row outside matrix_b.
+    const int32_t pivot = pivots_working_ptr[i];
+    if (pivot < 1 || static_cast<size_t>(pivot) > lu_matrix_size) {
+      MS_EXCEPTION(ValueError) << "For LuSolveCPUKernel lu_pivots[" << i << "] is " << pivot
+                               << ", but it should be in range [1, " << lu_matrix_size << "].";
+    }
+    matrix_b.row(i).swap(matrix_b.row(pivot - 1));
+  }
+  MatrixXd L = matrix_A.template triangularView<Eigen::UnitLower>();
+  MatrixXd U = matrix_A.template triangularView<Eigen::Upper>();
+  MatrixXd x = L * U;
+  MatrixXd result = x.lu().solve(matrix_b);
+  for (size_t m = 0; m < b_stride; m++) {
+    output_y[a * b_stride + m] = static_cast<T2>(result.data()[m]);
+  }
+}
+
+// Converts x and lu_data to T1 and solves each batch element, broadcasting batches when the input ranks differ.
+template <typename T1, typename T2>
+bool LuSolveCPUKernel<T1, T2>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &,
+                                      const std::vector<kernel::AddressPtr> &outputs) {
+  auto input_x0 = reinterpret_cast<T2 *>(inputs[0]->addr);
+  auto input_x1 = reinterpret_cast<T2 *>(inputs[1]->addr);
+  auto input_x2 = reinterpret_cast<int32_t *>(inputs[2]->addr);
+  auto input_0_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
+  auto input_1_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
+  auto output_Shape = AnfAlgo::GetOutputDeviceShape(node_wpt_, 0);
+  std::vector<T1> input_0(input_x0, input_x0 + get_element_num(input_0_Shape));
+  std::vector<T1> input_1(input_x1, input_x1 + get_element_num(input_1_Shape));
+  const size_t b_dims = input_0_Shape.size();
+  const size_t lu_dims = input_1_Shape.size();
+  const size_t b_stride = input_0_Shape[b_dims - 1] * input_0_Shape[b_dims - 2];
+  const size_t lu_stride = input_1_Shape[lu_dims - 1] * input_1_Shape[lu_dims - 2];
+  const size_t pivots_stride = input_1_Shape[lu_dims - 1];
+  MS_EXCEPTION_IF_ZERO("b_stride", b_stride);
+  const size_t batch_num = get_element_num(output_Shape) / b_stride;
+  if (b_dims == lu_dims) {  // Equal ranks: x and lu_data batches are paired one-to-one.
+    for (size_t i = 0; i < batch_num; i++) {
+      LuSolve(inputs, outputs, input_0.data() + i * b_stride, input_1.data() + i * lu_stride,
+              input_x2 + i * pivots_stride, b_stride, i);
+    }
+    return true;
+  }
+  // Different ranks: walk the broadcast of the two batch shapes (the shapes minus their trailing matrix dimensions).
+  std::vector<size_t> b_shape(input_0_Shape.begin(), input_0_Shape.end() - kDimNum);
+  std::vector<size_t> lu_shape(input_1_Shape.begin(), input_1_Shape.end() - kDimNum);
+  auto output_shape = CPUKernelUtils::GetBroadcastShape(b_shape, lu_shape);
+  BroadcastIterator iter(b_shape, lu_shape, output_shape);
+  iter.SetPos(0);
+  for (size_t i = 0; i < batch_num; i++) {
+    LuSolve(inputs, outputs, input_0.data() + iter.GetInputPosA() * b_stride,
+            input_1.data() + iter.GetInputPosB() * lu_stride, input_x2 + iter.GetInputPosB() * pivots_stride,
+            b_stride, i);
+    iter.GenNextPos();
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.h
new file mode 100644
index 00000000000..1875b8f93a8
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/lu_solve_cpu_kernel.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_
+// NOTE(review): the three system include targets were lost in this text; inferred from what the kernel uses.
+#include <Eigen/Dense>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+constexpr size_t kInputNum = 3;
+constexpr size_t kOutputNum = 1;
+namespace kernel {
+// CPU kernel for LuSolve. T1 is the element type used for the computation, T2 the tensor element type.
+template <typename T1, typename T2>
+class LuSolveCPUKernel : public CPUKernel {
+ public:
+  LuSolveCPUKernel() = default;
+  ~LuSolveCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+  // Solves batch element `a`; the three pointers address that element's x, lu_data and lu_pivots data.
+  void LuSolve(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs,
+               T1 *b_working_ptr, T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a);
+
+ private:
+  CNodePtr node_wpt_;  // Node stored by InitKernel; shapes are read from it when the kernel is launched.
+};
+
+MS_REG_CPU_KERNEL_T_S(LuSolve,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LuSolveCPUKernel, float, float16);
+MS_REG_CPU_KERNEL_T_S(LuSolve,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      LuSolveCPUKernel, float, float);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  //
MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 02791b0682c..72242753a94 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -557,6 +557,7 @@ inline const PrimitivePtr kPrimBesselI0 = std::make_shared("BesselI0" inline const PrimitivePtr kPrimBesselI1 = std::make_shared("BesselI1"); inline const PrimitivePtr kPrimGer = std::make_shared("Ger"); inline const PrimitivePtr kPrimCeil = std::make_shared("Ceil"); +inline const PrimitivePtr kPrimLuSolve = std::make_shared("LuSolve"); inline const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); inline const PrimitivePtr kPrimAdd = std::make_shared(kAdd); inline const PrimitivePtr kPrimAddcdiv = std::make_shared(kAddcdiv); diff --git a/mindspore/core/ops/lu_solve_.cc b/mindspore/core/ops/lu_solve_.cc new file mode 100644 index 00000000000..4c1022b87fa --- /dev/null +++ b/mindspore/core/ops/lu_solve_.cc @@ -0,0 +1,162 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "ops/lu_solve_.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+// Number of trailing dimensions forming the matrix part of x / lu_data; the leading ones are batch dimensions.
+constexpr size_t kMatrixDimNum = 2;
+// Position of lu_pivots in the input list.
+constexpr size_t kLuPivotsIndex = 2;
+
+// Formats the first count entries of shape as "d0 d1 ... " for error messages.
+template <typename ShapeT>
+std::string DimsToString(const ShapeT &shape, size_t count) {
+  std::ostringstream buffer;
+  for (size_t i = 0; i < count; i++) {
+    buffer << shape[i] << " ";
+  }
+  return buffer.str();
+}
+
+// Validates the x, lu_data and lu_pivots shapes and returns the output shape: x's shape when x has at least as
+// many dimensions as lu_data, otherwise lu_data's shape with its last dimension replaced by x's last dimension.
+abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto op_name = primitive->name();
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  auto lu_data_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
+  auto lu_pivots_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kLuPivotsIndex]->BuildShape())[kShape];
+  if (lu_data_shape.size() < kMatrixDimNum) {
+    MS_EXCEPTION(ValueError) << "For " << op_name << " lu_data's dimensions should be greater than or equal to 2.";
+  }
+  if (x_shape.size() < kMatrixDimNum) {
+    MS_EXCEPTION(ValueError) << "For " << op_name << " x's dimensions should be greater than or equal to 2.";
+  }
+  if (lu_pivots_shape.size() < 1) {
+    MS_EXCEPTION(ValueError) << "For " << op_name << " lu_pivots's dimensions should be greater than or equal to 1.";
+  }
+  const size_t x_batch = x_shape.size() - kMatrixDimNum;
+  const size_t lu_batch = lu_data_shape.size() - kMatrixDimNum;
+  if (lu_data_shape[lu_batch + 1] != lu_data_shape[lu_batch]) {
+    MS_EXCEPTION(ValueError) << "For " << op_name << " input lu_data should be square matrix "
+                             << "while row is " << lu_data_shape[lu_batch] << ", col is "
+                             << lu_data_shape[lu_batch + 1] << ".";
+  }
+  if (x_shape[x_batch] != lu_data_shape[lu_batch]) {
+    MS_EXCEPTION(ValueError) << "For " << op_name << " x's row count is not the same as lu_data's row count. "
+                             << "x is " << x_shape[x_batch] << ", lu_data is " << lu_data_shape[lu_batch] << ".";
+  }
+  // Batch dimensions are compared from the right: the shorter batch shape must match the tail of the longer one.
+  const size_t common = x_batch < lu_batch ? x_batch : lu_batch;
+  for (size_t i = 0; i < common; i++) {
+    if (x_shape[x_batch - common + i] != lu_data_shape[lu_batch - common + i]) {
+      MS_EXCEPTION(ValueError) << "For " << op_name
+                               << " x's batch dimension does not match lu_data's batch dimension, x's batch"
+                               << " dimension is [" << DimsToString(x_shape, x_batch)
+                               << "], lu_data's batch dimension is [" << DimsToString(lu_data_shape, lu_batch)
+                               << "], the batch dimensions may have different sizes, "
+                               << "from right to left, the corresponding dimensions must be equal.";
+    }
+  }
+  // lu_pivots is (*, m): one dimension fewer than lu_data. Checking the rank first keeps the batch loop below
+  // inside both shapes.
+  if (lu_pivots_shape.size() != lu_batch + 1) {
+    MS_EXCEPTION(ValueError) << "For " << op_name << " lu_pivots should have one dimension fewer than lu_data, "
+                             << "lu_data has " << lu_data_shape.size() << ", lu_pivots has "
+                             << lu_pivots_shape.size() << ".";
+  }
+  if (lu_pivots_shape[lu_batch] != lu_data_shape[lu_batch + 1]) {
+    MS_EXCEPTION(ValueError) << "For " << op_name
+                             << " the last dimension of lu_pivots must be equal to the last dimension of lu_data, "
+                             << "lu_data is " << lu_data_shape[lu_batch + 1] << ", lu_pivots is "
+                             << lu_pivots_shape[lu_batch] << ".";
+  }
+  for (size_t i = 0; i < lu_batch; i++) {
+    if (lu_data_shape[i] != lu_pivots_shape[i]) {
+      MS_EXCEPTION(ValueError) << "For " << op_name
+                               << " lu_data's batch dimension does not match lu_pivots's batch dimension, lu_data's"
+                               << " batch dimension is [" << DimsToString(lu_data_shape, lu_batch)
+                               << "], lu_pivots's batch dimension is [" << DimsToString(lu_pivots_shape, lu_batch)
+                               << "], the size of the dimension and the number of each dimension must be the same.";
+    }
+  }
+  if (x_shape.size() >= lu_data_shape.size()) {
+    return std::make_shared<abstract::Shape>(x_shape);
+  }
+  auto out_shape = lu_data_shape;
+  out_shape[out_shape.size() - 1] = x_shape[x_shape.size() - 1];
+  return std::make_shared<abstract::Shape>(out_shape);
+}
+
+// Checks that x and lu_data are float16/float32 tensors of one common type and that lu_pivots is an int32 tensor;
+// returns the common type of x and lu_data.
+TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  const std::set<TypePtr> valid_types = {kFloat32, kFloat16};
+  const std::set<TypePtr> valid_lu_pivots_types = {kInt32};
+  std::map<std::string, TypePtr> type;
+  (void)type.emplace("x", input_args[0]->BuildType());
+  (void)type.emplace("lu_data", input_args[1]->BuildType());
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("lu_data", input_args[1]->BuildType(), valid_types, prim->name());
+  auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(type, valid_types, prim->name());
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("lu_pivots", input_args[kLuPivotsIndex]->BuildType(),
+                                                   valid_lu_pivots_types, prim->name());
+  return out_type;
+}
+}  // namespace
+
+// Entry point registered for LuSolve: requires exactly three inputs, then infers the output type and shape.
+AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t input_num = 3;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  auto infer_type = InferType(primitive, input_args);
+  auto infer_shape = InferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+
+REGISTER_PRIMITIVE_EVAL_IMPL(LuSolve, prim::kPrimLuSolve, LuSolveInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/lu_solve_.h b/mindspore/core/ops/lu_solve_.h
new file mode 100644
index 00000000000..9d533cde8a1
--- /dev/null
+++ b/mindspore/core/ops/lu_solve_.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_LUSOLVE_H_
+#define MINDSPORE_CORE_OPS_LUSOLVE_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameLuSolve = "LuSolve";
+class LuSolve : public PrimitiveC {
+ public:
+  LuSolve() : PrimitiveC(kNameLuSolve) { InitIOName({"x", "lu_data", "lu_pivots"}, {"output"}); }
+  ~LuSolve() = default;
+  MS_DECLARE_PARENT(LuSolve, PrimitiveC);
+};
+AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const std::vector<AbstractBasePtr> &input_args);
+using PrimLuSolvePtr = std::shared_ptr<LuSolve>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_LUSOLVE_H_
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index
6bbe69078d0..ca2eada126e 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -14,6 +14,7 @@
 
 """aicpu ops"""
 from .unique import _unique_aicpu
+from .lu_solve import _lu_solve_aicpu
 from .no_repeat_ngram import _no_repeat_ngram_aicpu
 from .init_data_set_queue import _init_data_set_queue_aicpu
 from .embedding_lookup import _embedding_lookup_aicpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py
new file mode 100644
index 00000000000..d0165d96a49
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py
@@ -0,0 +1,32 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""AiCPU operator information for LuSolve."""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+lu_solve_op_info = (AiCPURegOp("LuSolve")
+                    .fusion_type("OPAQUE")
+                    .input(0, "x", "required")
+                    .input(1, "lu_data", "required")
+                    .input(2, "lu_pivots", "required")
+                    .output(0, "output", "required")
+                    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default)
+                    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default)
+                    .get_op_info())
+
+@op_info_register(lu_solve_op_info)
+def _lu_solve_aicpu():
+    """Register the LuSolve operator information for AiCPU."""
+    return None
diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py
index 35d26655460..6c3838ae486 100644
--- a/mindspore/python/mindspore/ops/operations/__init__.py
+++ b/mindspore/python/mindspore/ops/operations/__init__.py
@@ -61,7 +61,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                        Sin, Sqrt, Rsqrt, BesselI0, BesselI1, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Addcdiv,
                        Addcmul, Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps,
-                       Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose)
+                       Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose, LuSolve)
 
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
                          RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@@ -507,6 +507,7 @@ __all__ = [
     "NeighborExchange",
     "AlltoAll",
     "Custom",
+    "LuSolve",
 ]
 
 __sponge__ = [
diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py
index bb13d6a4941..01eb566d7a0 100644
---
a/mindspore/python/mindspore/ops/operations/math_ops.py
+++ b/mindspore/python/mindspore/ops/operations/math_ops.py
@@ -5779,6 +5779,7 @@ class Trunc(Primitive):
     def __init__(self):
         """Initialize Trunc"""
 
+
 class IsClose(Primitive):
     r"""
     Returns a boolean tensor where two tensors are element-wise equal within a tolerance.
@@ -5835,3 +5836,55 @@ class IsClose(Primitive):
             raise ValueError("For IsClose, the `equal_nan` must be True, but got False.")
         validator.check_non_negative_float(rtol, 'rtol', self.name)
         validator.check_non_negative_float(atol, 'atol', self.name)
+
+
+class LuSolve(Primitive):
+    """
+    Return the solution of the linear equation Ax = b.
+
+    Note:
+        The batch dimensions of lu_pivots must match the batch dimensions of lu_data: the number of batch dimensions
+        and the size of each one must be the same. For example, if lu_data is (3, 3, 2, 2) and lu_pivots is
+        (3, 3, 2), lu_data's batch dimensions are (3, 3) and lu_pivots's batch dimensions are (3, 3).
+
+        The batch dimensions of lu_data must match the batch dimensions of x. The two may have a different number
+        of batch dimensions, but compared from right to left the corresponding dimensions must be equal. For example,
+        if lu_data is (3, 3, 2, 2) and x is (2, 3, 3, 2, 1), their batch dimensions are (3, 3) and (2, 3, 3).
+
+    Inputs:
+        - **x** (Tensor) - The input is a tensor of size (*, m, k), where * is batch dimensions, with data type
+          float32, float16.
+        - **lu_data** (Tensor) - The input is a tensor of size (*, m, m), where * is batch dimensions, that can
+          be decomposed into an upper triangular matrix U and a lower triangular matrix L, with data type
+          float32, float16.
+        - **lu_pivots** (Tensor) - The input is a tensor of size (*, m), where * is batch dimensions, that can
+          be converted to a permutation matrix P, with data type int32.
+
+    Outputs:
+        Tensor, with the same data type as `x` and `lu_data`.
+
+    Raises:
+        TypeError: If dtype of `x` or `lu_data` is not one of: float32, float16.
+        TypeError: If dtype of `lu_pivots` is not int32.
+        TypeError: If `x`, `lu_data` or `lu_pivots` is not a Tensor.
+        TypeError: If dtype of `x` is not the same as dtype of `lu_data`.
+        ValueError: If the batch dimensions of `lu_pivots` do not match the batch dimensions of `lu_data`.
+        ValueError: If `x` or `lu_data` has fewer than 2 dimensions, or `lu_pivots` has fewer than 1 dimension.
+
+    Supported Platforms:
+        ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([[1], [3], [3]]), mindspore.float32)
+        >>> lu_data = Tensor(np.array([[2, 1, 1], [0.5, 1, 1.5], [0.5, 0, 2.5]]), mindspore.float32)
+        >>> lu_pivots = Tensor(np.array([2, 2, 3]), mindspore.int32)
+        >>> net = ops.LuSolve()
+        >>> y = net(x, lu_data, lu_pivots)
+        >>> print(y)
+        [[ 1.9000002]
+         [-1.4000001]
+         [ 0.6      ]]
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize LuSolve"""
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index f861528d956..41ee3603882 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -1543,6 +1543,12 @@ test_case_math_ops = [
         'block': P.Square(),
         'desc_inputs': [[4]],
         'desc_bprop': [[4]]}),
+    ('LuSolve', {
+        'block': P.LuSolve(),
+        'desc_inputs': [Tensor(np.random.rand(3, 3), mstype.float32),
+                        Tensor(np.random.rand(3, 3), mstype.float32),
+                        Tensor(np.array([2, 2, 3]), mstype.int32)],
+        'skip': ['backward']}),
     ('Rsqrt', {
         'block': P.Rsqrt(),
         'desc_inputs': [[4]],