!23958 [assistant][ops] Add LuSolve

Merge pull request !23958 from 张渝/lu_solve
This commit is contained in:
i-robot 2022-01-20 01:42:24 +00:00 committed by Gitee
commit 46fdab3848
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 541 additions and 1 deletions

View File

@ -0,0 +1,177 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/lu_solve_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kDimNum = 2;
}
size_t get_element_num(const std::vector<size_t> &shape) {
size_t size = 1;
for (size_t i = 0; i < shape.size(); i++) {
size *= shape[i];
}
return size;
}
template <typename T1, typename T2>
void LuSolveCPUKernel<T1, T2>::InitKernel(const CNodePtr &kernel_node) {
node_wpt_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputNum, kernel_name_);
auto x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto lu_data_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
auto lu_pivots_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
if (lu_data_shape.size() < kDimNum) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel lu_data's dimensions should be greater than or equal to 2.";
}
if (x_shape.size() < kDimNum) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel x's dimensions should be greater than or equal to 2.";
}
if (lu_pivots_shape.size() < 1) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel lu_pivots's dimensions should be greater than or equal to 1.";
}
if (lu_data_shape[lu_data_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - kDimNum])
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
<< " input lu_data should be square matrix "
<< "while row is " << lu_data_shape[lu_data_shape.size() - kDimNum] << ", col is "
<< lu_data_shape[lu_data_shape.size() - 1] << ".";
if (x_shape.size() == lu_data_shape.size()) {
for (size_t i = 0; i <= x_shape.size() - kDimNum; i++) {
if (x_shape[i] != lu_data_shape[i]) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
<< " shapes in dim[" << i << "] are not the same "
<< "while x is " << x_shape[i] << ", lu_data is " << lu_data_shape[i] << ".";
}
}
} else if (lu_data_shape.size() > x_shape.size()) {
for (size_t i = 0; i < x_shape.size() - kDimNum; i++) {
if (x_shape[i] != lu_data_shape[lu_data_shape.size() - x_shape.size() + i]) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel"
<< " shapes in dim[" << i << "] are not same as lu_data's dim["
<< lu_data_shape.size() - x_shape.size() + i << "]"
<< "while x is " << x_shape[i] << ", lu_data is " << lu_data_shape[i] << ".";
}
}
} else {
for (size_t i = 0; i < lu_data_shape.size() - kDimNum; i++) {
if (lu_data_shape[i] != x_shape[x_shape.size() - lu_data_shape.size() + i]) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
<< " shapes in lu_data's dim[" << i << "] are not same as x's dim["
<< x_shape.size() - lu_data_shape.size() + i << "]"
<< "while x is " << x_shape[x_shape.size() - lu_data_shape.size() + i]
<< ", lu_data is " << lu_data_shape[i] << ".";
}
}
}
if (lu_pivots_shape[lu_pivots_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - 1]) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
<< " Number of pivots per batch should be same as the dimension of the matrix.";
}
for (size_t i = 0; i < lu_pivots_shape.size(); i++) {
if (lu_data_shape[i] != lu_pivots_shape[i]) {
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
<< "batch dimension of LU_pivots should match batch dimension of LU_data.";
}
}
}
template <typename T1, typename T2>
void LuSolveCPUKernel<T1, T2>::LuSolve(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs, T1 *b_working_ptr,
T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a) {
auto input_0_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
auto input_1_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
auto output_y = reinterpret_cast<T2 *>(outputs[0]->addr);
size_t lu_dims = input_1_Shape.size();
size_t lu_maxtrix_sizes = input_1_Shape[lu_dims - 2];
size_t b_dim = input_0_Shape.size();
size_t b_m = input_0_Shape[b_dim - 1];
typedef Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
MatrixXd matrix_b = Eigen::Map<MatrixXd>(b_working_ptr, lu_maxtrix_sizes, b_m);
MatrixXd matrix_A = Eigen::Map<MatrixXd>(lu_working_ptr, lu_maxtrix_sizes, lu_maxtrix_sizes);
for (size_t i = 0; i < input_0_Shape[b_dim - kDimNum]; i++) {
matrix_b.row(i).swap(matrix_b.row(*(pivots_working_ptr + i) - 1));
}
MatrixXd L = matrix_A.template triangularView<Eigen::UnitLower>();
MatrixXd U = matrix_A.template triangularView<Eigen::Upper>();
MatrixXd x = L * U;
MatrixXd result = x.lu().solve(matrix_b);
for (size_t m = 0; m < b_stride; m++) {
*(output_y + a * b_stride + m) = (T2) * (result.data() + m);
}
}
template <typename T1, typename T2>
bool LuSolveCPUKernel<T1, T2>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_x0 = reinterpret_cast<T2 *>(inputs[0]->addr);
auto input_x1 = reinterpret_cast<T2 *>(inputs[1]->addr);
auto input_x2 = reinterpret_cast<int32_t *>(inputs[2]->addr);
auto input_0_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
auto input_1_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
auto output_Shape = AnfAlgo::GetOutputDeviceShape(node_wpt_, 0);
size_t input0_element_num = get_element_num(input_0_Shape);
size_t input1_element_num = get_element_num(input_1_Shape);
size_t output_element_num = get_element_num(output_Shape);
std::vector<T1> input_0(input_x0, input_x0 + input0_element_num);
std::vector<T1> input_1(input_x1, input_x1 + input1_element_num);
size_t b_dims = input_0_Shape.size();
std::vector<size_t> b_dims_vector = input_0_Shape;
size_t lu_dims = input_1_Shape.size();
std::vector<size_t> lu_dims_vector = input_1_Shape;
size_t b_stride = input_0_Shape[b_dims - 1] * input_0_Shape[b_dims - 2];
size_t lu_stride = input_1_Shape[lu_dims - 1] * input_1_Shape[lu_dims - 2];
size_t pivots_stride = input_1_Shape[lu_dims - 1];
MS_EXCEPTION_IF_ZERO("b_stride", b_stride);
size_t batch_num = output_element_num / b_stride;
if (b_dims == lu_dims) {
for (size_t i = 0; i < batch_num; i++) {
T1 *b_working_ptr = input_0.data() + i * b_stride;
T1 *lu_working_ptr = input_1.data() + i * lu_stride;
int32_t *pivots_working_ptr = &input_x2[i * pivots_stride];
LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i);
}
} else {
std::vector<size_t> b_shape = b_dims_vector;
std::vector<size_t> lu_shape = lu_dims_vector;
for (size_t i = 0; i < kDimNum; i++) {
b_shape.pop_back();
lu_shape.pop_back();
}
auto output_shape = CPUKernelUtils::GetBroadcastShape(b_shape, lu_shape);
BroadcastIterator iter(b_shape, lu_shape, output_shape);
iter.SetPos(0);
for (size_t i = 0; i < batch_num; i++) {
T1 *b_working_ptr = input_0.data() + iter.GetInputPosA() * b_stride;
T1 *lu_working_ptr = input_1.data() + iter.GetInputPosB() * lu_stride;
int32_t *pivots_working_ptr = &input_x2[iter.GetInputPosB() * pivots_stride];
LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i);
iter.GenNextPos();
}
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_
#include <Eigen/Dense>
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
constexpr size_t kInputNum = 3;
constexpr size_t kOutputNum = 1;
namespace kernel {
template <typename T1, typename T2>
class LuSolveCPUKernel : public CPUKernel {
public:
LuSolveCPUKernel() = default;
~LuSolveCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
void LuSolve(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs,
T1 *b_working_ptr, T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a);
private:
CNodePtr node_wpt_;
std::vector<size_t> input_0_Shape;
std::vector<size_t> input_1_Shape;
std::vector<size_t> input_2_Shape;
};
MS_REG_CPU_KERNEL_T_S(LuSolve,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
LuSolveCPUKernel, float, float16);
MS_REG_CPU_KERNEL_T_S(LuSolve,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
LuSolveCPUKernel, float, float);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_

View File

@ -561,6 +561,7 @@ inline const PrimitivePtr kPrimBesselI0 = std::make_shared<Primitive>("BesselI0"
inline const PrimitivePtr kPrimBesselI1 = std::make_shared<Primitive>("BesselI1");
inline const PrimitivePtr kPrimGer = std::make_shared<Primitive>("Ger");
inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil");
inline const PrimitivePtr kPrimLuSolve = std::make_shared<Primitive>("LuSolve");
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>(kAdd);
inline const PrimitivePtr kPrimAddcdiv = std::make_shared<Primitive>(kAddcdiv);

View File

@ -0,0 +1,162 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/lu_solve_.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#define LuSolve_for(shape) \
do { \
for (auto item : shape) { \
buffer << item << " "; \
} \
} while (0)
#define LuSolve_pop() \
do { \
for (size_t i = 0; i < 2; i++) { \
x_shape.pop_back(); \
lu_data_shape.pop_back(); \
} \
} while (0)
#define LuSolve_buffer(x_shape, lu_data_shape) \
do { \
LuSolve_pop(); \
buffer << "For LuSolve x's batch dimension does not match lu_data's batch dimension, x's batch dimension is ["; \
LuSolve_for(x_shape); \
buffer << "], lu_data's batch dimension is ["; \
LuSolve_for(lu_data_shape); \
buffer << "], the batch dimensions may have different sizes, "; \
buffer << "from right to left, the corresponding dimensions must be equal."; \
} while (0)
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
const int64_t kDimNum = 2;
std::ostringstream buffer;
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto x_shape = x_shape_map[kShape];
auto lu_data_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto lu_data_shape = lu_data_shape_map[kShape];
auto lu_pivots_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape());
auto lu_pivots_shape = lu_pivots_shape_map[kShape];
if (lu_data_shape.size() < kDimNum) {
MS_EXCEPTION(ValueError) << "For " << op_name << " lu_data's dimensions should be greater than or equal to 2.";
}
if (x_shape.size() < kDimNum) {
MS_EXCEPTION(ValueError) << "For " << op_name << " x's dimensions should be greater than or equal to 2.";
}
if (lu_pivots_shape.size() < 1) {
MS_EXCEPTION(ValueError) << "For " << op_name << " lu_pivots's dimensions should be greater than or equal to 1.";
}
if (lu_data_shape[lu_data_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - kDimNum]) {
MS_EXCEPTION(ValueError) << "For " << op_name << " input lu_data should be square matrix "
<< "while row is " << lu_data_shape[lu_data_shape.size() - kDimNum] << ", col is "
<< lu_data_shape[lu_data_shape.size() - 1] << ".";
}
if (x_shape[x_shape.size() - kDimNum] != lu_data_shape[lu_data_shape.size() - kDimNum]) {
MS_EXCEPTION(ValueError) << "For " << op_name << " x's col rank is not same as lu_data's col rank. "
<< "x is " << x_shape[x_shape.size() - kDimNum] << ", lu_data is "
<< lu_data_shape[lu_data_shape.size() - kDimNum] << ".";
}
if (x_shape.size() == lu_data_shape.size()) {
for (size_t i = 0; i <= x_shape.size() - kDimNum; i++) {
if (x_shape[i] != lu_data_shape[i]) {
LuSolve_buffer(x_shape, lu_data_shape);
MS_EXCEPTION(ValueError) << buffer.str();
}
}
} else if (lu_data_shape.size() > x_shape.size()) {
for (size_t i = 0; i < x_shape.size() - kDimNum; i++) {
if (x_shape[i] != lu_data_shape[lu_data_shape.size() - x_shape.size() + i]) {
LuSolve_buffer(x_shape, lu_data_shape);
MS_EXCEPTION(ValueError) << buffer.str();
}
}
} else {
for (size_t i = 0; i < lu_data_shape.size() - kDimNum; i++) {
if (lu_data_shape[i] != x_shape[x_shape.size() - lu_data_shape.size() + i]) {
LuSolve_buffer(x_shape, lu_data_shape);
MS_EXCEPTION(ValueError) << buffer.str();
}
}
}
if (lu_pivots_shape[lu_pivots_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - 1]) {
MS_EXCEPTION(ValueError) << "For " << op_name
<< " the last dimension of lu_pivots must be equal to the last dimension of lu_data, "
<< "lu_data is " << lu_data_shape[lu_data_shape.size() - 1] << ", lu_pivots is "
<< lu_pivots_shape[lu_pivots_shape.size() - 1] << ".";
}
for (size_t i = 0; i < lu_pivots_shape.size(); i++) {
if (lu_data_shape[i] != lu_pivots_shape[i]) {
x_shape.pop_back();
x_shape.pop_back();
lu_pivots_shape.pop_back();
buffer << "For " << op_name
<< " lu_data's batch dimension does not match lu_pivots's batch dimension, lu_data's batch dimension is [";
LuSolve_for(x_shape);
buffer << "], lu_pivots's batch dimension is [";
LuSolve_for(lu_pivots_shape);
buffer << "], the size of the dimension and the number of each dimension must be the same.";
MS_EXCEPTION(ValueError) << buffer.str();
}
}
auto dim_vector = lu_data_shape;
if (x_shape.size() >= lu_data_shape.size()) {
return std::make_shared<abstract::Shape>(x_shape);
} else {
dim_vector[lu_data_shape.size() - 1] = x_shape[x_shape.size() - 1];
return std::make_shared<abstract::Shape>(dim_vector);
}
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const int64_t kDimNum = 2;
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> type;
(void)type.emplace("x", input_args[0]->BuildType());
(void)type.emplace("lu_data", input_args[1]->BuildType());
const std::set<TypePtr> valid_types = {kFloat32, kFloat16};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("lu_data", input_args[1]->BuildType(), valid_types, prim->name());
auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(type, valid_types, prim->name());
const std::set<TypePtr> valid_lu_pivots_types = {kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("lu_pivots", input_args[kDimNum]->BuildType(), valid_lu_pivots_types,
prim->name());
return out_type;
}
} // namespace
AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(LuSolve, prim::kPrimLuSolve, LuSolveInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
#ifndef MINDSPORE_CORE_OPS_LUSOLVE_H_
#define MINDSPORE_CORE_OPS_LUSOLVE_H_
#include <map>
#include <set>
#include <string>
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLuSolve = "LuSolve";
class LuSolve : public PrimitiveC {
public:
LuSolve() : PrimitiveC(kNameLuSolve) { InitIOName({"x", "lu_data", "lu_pivots"}, {"output"}); }
~LuSolve() = default;
MS_DECLARE_PARENT(LuSolve, PrimitiveC);
};
AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimLuSolvePtr = std::shared_ptr<LuSolve>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LUSOLVE_H_

View File

@ -14,6 +14,7 @@
"""aicpu ops"""
from .unique import _unique_aicpu
from .lu_solve import _lu_solve_aicpu
from .no_repeat_ngram import _no_repeat_ngram_aicpu
from .init_data_set_queue import _init_data_set_queue_aicpu
from .embedding_lookup import _embedding_lookup_aicpu

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LuSolve op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
lu_solve_op_info = AiCPURegOp("LuSolve") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "lu_data", "required") \
.input(2, "lu_pivots", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.get_op_info()
@op_info_register(lu_solve_op_info)
def _lu_solve_aicpu():
"""LuSolve aicpu register"""
return

View File

@ -62,7 +62,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0, BesselI1, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Addcdiv,
Addcmul, Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps,
Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose)
Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose, LuSolve)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@ -510,6 +510,7 @@ __all__ = [
"NeighborExchange",
"AlltoAll",
"Custom",
"LuSolve",
]
__sponge__ = [

View File

@ -5818,6 +5818,7 @@ class Trunc(Primitive):
def __init__(self):
"""Initialize Trunc"""
class IsClose(Primitive):
r"""
Returns a boolean tensor where two tensors are element-wise equal within a tolerance.
@ -5874,3 +5875,55 @@ class IsClose(Primitive):
raise ValueError("For IsClose, the `equal_nan` must be True, but got False.")
validator.check_non_negative_float(rtol, 'rtol', self.name)
validator.check_non_negative_float(atol, 'atol', self.name)
class LuSolve(Primitive):
"""
Return the solution of the linear equation Ax = b.
Note:
The batch dimensions of lu_pivots must match the batch dimensions of lu_data, the size of the dimension and the
number of each dimension must be the same. For example, lu_data is (3, 3, 2, 2) lu_pivots is (3, 3, 2),
lu_data's batch dimensions is (3, 3), lu_pivots's batch dimensions is (3, 3).
The batch dimensions of lu_data must match the batch dimensions of x, the batch dimensions may have
different sizes, from right to left, the corresponding dimensions must be equal. For example, lu_data
is (3, 3, 2, 2) x is (2, 3, 3, 2, 1), lu_data's batch dimensions is (3, 3), x's batch dimensions is (2, 3, 3).
Inputs:
- **x** (Tensor) - The input is a tensor of size (*, m, k), where * is batch dimensions, with data type
float32, float16.
- **lu_data** (Tensor) - The input is a tensor of size (*, m, m), where * is batch dimensions, that can
be decomposed into an upper
triangular matrix U and a lower triangular matrix L, with data type float32, float16.
- **lu_pivots** (Tensor) - The input is a tensor of size (*, m), where * is batch dimensions, that can
be converted to a permutation matrix P, with data type int32.
Outputs:
Tensor, the same data type as the x and lu_data.
Raises:
TypeError: If dtype of `x` or `lu_data` is not one of: float32, float16.
TypeError: If dtype of `lu_pivots` is not: int32.
TypeError: If `x`, `lu_data` or `lu_pivots` is not Tensor.
TypeError: If dtype of `x` is not same as dtype of `lu_data`.
ValueError: If the batch dimensions of lu_pivots does not match the batch dimensions of lu_data.
ValueError: If `x` dimension less than 2, `lu_data` dimension less than 2 or `lu_pivots` dimension less than 1.
Supported Platforms:
``CPU``
Examples:
>>> x = Tensor(np.array([[1], [3], [3]]), mindspore.float32)
>>> lu_data = Tensor(np.array([[2, 1, 1], [0.5, 1, 1.5], [0.5, 0, 2.5]]), mindspore.float32)
>>> lu_pivots = Tensor(np.array([2, 2, 3]), mindspore.int32)
>>> net = ops.LuSolve()
>>> y = net(x, lu_data, lu_pivots)
>>> print(y)
[[ 1.9000002]
[-1.4000001]
[ 0.6 ]]
"""
@prim_attr_register
def __init__(self):
pass

View File

@ -1543,6 +1543,12 @@ test_case_math_ops = [
'block': P.Square(),
'desc_inputs': [[4]],
'desc_bprop': [[4]]}),
('LuSolve', {
'block': P.LuSolve(),
'desc_inputs': [Tensor(np.random.rand(3, 3), mstype.float32),
Tensor(np.random.rand(3, 3), mstype.float32),
Tensor(np.random.rand(3), mstype.int32)],
'skip': ['backward']}),
('Rsqrt', {
'block': P.Rsqrt(),
'desc_inputs': [[4]],