forked from mindspore-Ecosystem/mindspore
!23958 [assistant][ops] Add LuSolve
Merge pull request !23958 from 张渝/lu_solve
This commit is contained in:
commit
46fdab3848
|
@ -0,0 +1,177 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/lu_solve_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
// Number of trailing dimensions that form one matrix inside a (possibly batched) tensor.
constexpr size_t kDimNum{2};
}  // namespace
|
||||
|
||||
// Returns the product of all dimensions in `shape`; an empty shape yields 1.
size_t get_element_num(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (const size_t dim : shape) {
    total *= dim;
  }
  return total;
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
void LuSolveCPUKernel<T1, T2>::InitKernel(const CNodePtr &kernel_node) {
|
||||
node_wpt_ = kernel_node;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputNum, kernel_name_);
|
||||
auto x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto lu_data_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
auto lu_pivots_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
|
||||
if (lu_data_shape.size() < kDimNum) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel lu_data's dimensions should be greater than or equal to 2.";
|
||||
}
|
||||
if (x_shape.size() < kDimNum) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel x's dimensions should be greater than or equal to 2.";
|
||||
}
|
||||
if (lu_pivots_shape.size() < 1) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel lu_pivots's dimensions should be greater than or equal to 1.";
|
||||
}
|
||||
if (lu_data_shape[lu_data_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - kDimNum])
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
|
||||
<< " input lu_data should be square matrix "
|
||||
<< "while row is " << lu_data_shape[lu_data_shape.size() - kDimNum] << ", col is "
|
||||
<< lu_data_shape[lu_data_shape.size() - 1] << ".";
|
||||
|
||||
if (x_shape.size() == lu_data_shape.size()) {
|
||||
for (size_t i = 0; i <= x_shape.size() - kDimNum; i++) {
|
||||
if (x_shape[i] != lu_data_shape[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
|
||||
<< " shapes in dim[" << i << "] are not the same "
|
||||
<< "while x is " << x_shape[i] << ", lu_data is " << lu_data_shape[i] << ".";
|
||||
}
|
||||
}
|
||||
} else if (lu_data_shape.size() > x_shape.size()) {
|
||||
for (size_t i = 0; i < x_shape.size() - kDimNum; i++) {
|
||||
if (x_shape[i] != lu_data_shape[lu_data_shape.size() - x_shape.size() + i]) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel"
|
||||
<< " shapes in dim[" << i << "] are not same as lu_data's dim["
|
||||
<< lu_data_shape.size() - x_shape.size() + i << "]"
|
||||
<< "while x is " << x_shape[i] << ", lu_data is " << lu_data_shape[i] << ".";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < lu_data_shape.size() - kDimNum; i++) {
|
||||
if (lu_data_shape[i] != x_shape[x_shape.size() - lu_data_shape.size() + i]) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
|
||||
<< " shapes in lu_data's dim[" << i << "] are not same as x's dim["
|
||||
<< x_shape.size() - lu_data_shape.size() + i << "]"
|
||||
<< "while x is " << x_shape[x_shape.size() - lu_data_shape.size() + i]
|
||||
<< ", lu_data is " << lu_data_shape[i] << ".";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (lu_pivots_shape[lu_pivots_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - 1]) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
|
||||
<< " Number of pivots per batch should be same as the dimension of the matrix.";
|
||||
}
|
||||
for (size_t i = 0; i < lu_pivots_shape.size(); i++) {
|
||||
if (lu_data_shape[i] != lu_pivots_shape[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For LuSolveCPUKercel "
|
||||
<< "batch dimension of LU_pivots should match batch dimension of LU_data.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
void LuSolveCPUKernel<T1, T2>::LuSolve(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs, T1 *b_working_ptr,
|
||||
T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a) {
|
||||
auto input_0_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
auto input_1_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
|
||||
auto output_y = reinterpret_cast<T2 *>(outputs[0]->addr);
|
||||
size_t lu_dims = input_1_Shape.size();
|
||||
size_t lu_maxtrix_sizes = input_1_Shape[lu_dims - 2];
|
||||
size_t b_dim = input_0_Shape.size();
|
||||
size_t b_m = input_0_Shape[b_dim - 1];
|
||||
typedef Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
|
||||
MatrixXd matrix_b = Eigen::Map<MatrixXd>(b_working_ptr, lu_maxtrix_sizes, b_m);
|
||||
MatrixXd matrix_A = Eigen::Map<MatrixXd>(lu_working_ptr, lu_maxtrix_sizes, lu_maxtrix_sizes);
|
||||
for (size_t i = 0; i < input_0_Shape[b_dim - kDimNum]; i++) {
|
||||
matrix_b.row(i).swap(matrix_b.row(*(pivots_working_ptr + i) - 1));
|
||||
}
|
||||
MatrixXd L = matrix_A.template triangularView<Eigen::UnitLower>();
|
||||
MatrixXd U = matrix_A.template triangularView<Eigen::Upper>();
|
||||
MatrixXd x = L * U;
|
||||
MatrixXd result = x.lu().solve(matrix_b);
|
||||
for (size_t m = 0; m < b_stride; m++) {
|
||||
*(output_y + a * b_stride + m) = (T2) * (result.data() + m);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
bool LuSolveCPUKernel<T1, T2>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto input_x0 = reinterpret_cast<T2 *>(inputs[0]->addr);
|
||||
auto input_x1 = reinterpret_cast<T2 *>(inputs[1]->addr);
|
||||
auto input_x2 = reinterpret_cast<int32_t *>(inputs[2]->addr);
|
||||
auto input_0_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
auto input_1_Shape = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
|
||||
auto output_Shape = AnfAlgo::GetOutputDeviceShape(node_wpt_, 0);
|
||||
size_t input0_element_num = get_element_num(input_0_Shape);
|
||||
size_t input1_element_num = get_element_num(input_1_Shape);
|
||||
size_t output_element_num = get_element_num(output_Shape);
|
||||
std::vector<T1> input_0(input_x0, input_x0 + input0_element_num);
|
||||
std::vector<T1> input_1(input_x1, input_x1 + input1_element_num);
|
||||
size_t b_dims = input_0_Shape.size();
|
||||
std::vector<size_t> b_dims_vector = input_0_Shape;
|
||||
size_t lu_dims = input_1_Shape.size();
|
||||
std::vector<size_t> lu_dims_vector = input_1_Shape;
|
||||
size_t b_stride = input_0_Shape[b_dims - 1] * input_0_Shape[b_dims - 2];
|
||||
size_t lu_stride = input_1_Shape[lu_dims - 1] * input_1_Shape[lu_dims - 2];
|
||||
size_t pivots_stride = input_1_Shape[lu_dims - 1];
|
||||
MS_EXCEPTION_IF_ZERO("b_stride", b_stride);
|
||||
size_t batch_num = output_element_num / b_stride;
|
||||
if (b_dims == lu_dims) {
|
||||
for (size_t i = 0; i < batch_num; i++) {
|
||||
T1 *b_working_ptr = input_0.data() + i * b_stride;
|
||||
T1 *lu_working_ptr = input_1.data() + i * lu_stride;
|
||||
int32_t *pivots_working_ptr = &input_x2[i * pivots_stride];
|
||||
LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i);
|
||||
}
|
||||
} else {
|
||||
std::vector<size_t> b_shape = b_dims_vector;
|
||||
std::vector<size_t> lu_shape = lu_dims_vector;
|
||||
for (size_t i = 0; i < kDimNum; i++) {
|
||||
b_shape.pop_back();
|
||||
lu_shape.pop_back();
|
||||
}
|
||||
auto output_shape = CPUKernelUtils::GetBroadcastShape(b_shape, lu_shape);
|
||||
BroadcastIterator iter(b_shape, lu_shape, output_shape);
|
||||
iter.SetPos(0);
|
||||
for (size_t i = 0; i < batch_num; i++) {
|
||||
T1 *b_working_ptr = input_0.data() + iter.GetInputPosA() * b_stride;
|
||||
T1 *lu_working_ptr = input_1.data() + iter.GetInputPosB() * lu_stride;
|
||||
int32_t *pivots_working_ptr = &input_x2[iter.GetInputPosB() * pivots_stride];
|
||||
LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i);
|
||||
iter.GenNextPos();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_
|
||||
#include <Eigen/Dense>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
// Number of inputs the LuSolve kernel expects.
constexpr size_t kInputNum{3};
// Number of outputs the LuSolve kernel produces.
constexpr size_t kOutputNum{1};
|
||||
namespace kernel {
|
||||
template <typename T1, typename T2>
|
||||
class LuSolveCPUKernel : public CPUKernel {
|
||||
public:
|
||||
LuSolveCPUKernel() = default;
|
||||
~LuSolveCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
void LuSolve(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs,
|
||||
T1 *b_working_ptr, T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a);
|
||||
|
||||
private:
|
||||
CNodePtr node_wpt_;
|
||||
std::vector<size_t> input_0_Shape;
|
||||
std::vector<size_t> input_1_Shape;
|
||||
std::vector<size_t> input_2_Shape;
|
||||
};
|
||||
|
||||
// Registration for float16 tensors (x, lu_data, output) with int32 pivots.
// NOTE(review): presumably instantiates LuSolveCPUKernel<float, float16>, i.e. float16 storage with
// float computation — confirm against the MS_REG_CPU_KERNEL_T_S macro definition.
MS_REG_CPU_KERNEL_T_S(LuSolve,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
LuSolveCPUKernel, float, float16);
|
||||
// Registration for float32 tensors (x, lu_data, output) with int32 pivots.
MS_REG_CPU_KERNEL_T_S(LuSolve,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
LuSolveCPUKernel, float, float);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LUSOLVE_CPU_KERNEL_H_
|
|
@ -561,6 +561,7 @@ inline const PrimitivePtr kPrimBesselI0 = std::make_shared<Primitive>("BesselI0"
|
|||
inline const PrimitivePtr kPrimBesselI1 = std::make_shared<Primitive>("BesselI1");
|
||||
inline const PrimitivePtr kPrimGer = std::make_shared<Primitive>("Ger");
|
||||
inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil");
|
||||
// Shared Primitive object for the LuSolve operator, created with the name "LuSolve".
inline const PrimitivePtr kPrimLuSolve = std::make_shared<Primitive>("LuSolve");
|
||||
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
|
||||
inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>(kAdd);
|
||||
inline const PrimitivePtr kPrimAddcdiv = std::make_shared<Primitive>(kAddcdiv);
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/lu_solve_.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
// Appends every element of `shape`, each followed by one space, to the local stream `buffer`.
#define LuSolve_for(shape)             \
  do {                                 \
    for (const auto &item : shape) {   \
      buffer << item << " ";           \
    }                                  \
  } while (0)

// Drops the two trailing (matrix) dimensions from the locals `x_shape` and `lu_data_shape`,
// leaving only their batch dimensions.
#define LuSolve_pop()         \
  do {                        \
    x_shape.pop_back();       \
    lu_data_shape.pop_back(); \
    x_shape.pop_back();       \
    lu_data_shape.pop_back(); \
  } while (0)

// Writes the "batch dimensions do not match" message for x and lu_data into the local stream
// `buffer`. The shapes are reduced to their batch dimensions first, so they are modified.
#define LuSolve_buffer(x_shape, lu_data_shape)                                                                  \
  do {                                                                                                          \
    LuSolve_pop();                                                                                              \
    buffer << "For LuSolve x's batch dimension does not match lu_data's batch dimension, x's batch dimension is ["; \
    LuSolve_for(x_shape);                                                                                       \
    buffer << "], lu_data's batch dimension is [";                                                              \
    LuSolve_for(lu_data_shape);                                                                                 \
    buffer << "], the batch dimensions may have different sizes, "                                              \
           << "from right to left, the corresponding dimensions must be equal.";                                \
  } while (0)
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
/**
 * Validates the input shapes of LuSolve and infers the output shape.
 *
 * input_args[0] is x, input_args[1] is lu_data and input_args[2] is lu_pivots. The following is
 * enforced, raising ValueError on failure:
 *   - x and lu_data have at least 2 dimensions, lu_pivots has at least 1;
 *   - the two trailing dimensions of lu_data are equal;
 *   - x and lu_data have the same number of rows;
 *   - the batch dimensions of x and lu_data agree, compared from right to left;
 *   - the last dimension of lu_pivots equals the last dimension of lu_data;
 *   - lu_pivots has exactly one dimension fewer than lu_data and matches its leading dimensions.
 *
 * @return The shape of x when x has at least as many dimensions as lu_data; otherwise the shape of
 *         lu_data with its last dimension replaced by the last dimension of x.
 */
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  // Unsigned so that comparisons and arithmetic with size() do not mix signedness.
  const size_t kDimNum = 2;
  std::ostringstream buffer;
  auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
  auto x_shape = x_shape_map[kShape];
  auto lu_data_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
  auto lu_data_shape = lu_data_shape_map[kShape];
  auto lu_pivots_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape());
  auto lu_pivots_shape = lu_pivots_shape_map[kShape];
  if (lu_data_shape.size() < kDimNum) {
    MS_EXCEPTION(ValueError) << "For " << op_name << " lu_data's dimensions should be greater than or equal to 2.";
  }
  if (x_shape.size() < kDimNum) {
    MS_EXCEPTION(ValueError) << "For " << op_name << " x's dimensions should be greater than or equal to 2.";
  }
  if (lu_pivots_shape.size() < 1) {
    MS_EXCEPTION(ValueError) << "For " << op_name << " lu_pivots's dimensions should be greater than or equal to 1.";
  }
  if (lu_data_shape[lu_data_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - kDimNum]) {
    MS_EXCEPTION(ValueError) << "For " << op_name << " input lu_data should be square matrix "
                             << "while row is " << lu_data_shape[lu_data_shape.size() - kDimNum] << ", col is "
                             << lu_data_shape[lu_data_shape.size() - 1] << ".";
  }
  if (x_shape[x_shape.size() - kDimNum] != lu_data_shape[lu_data_shape.size() - kDimNum]) {
    MS_EXCEPTION(ValueError) << "For " << op_name << " x's col rank is not same as lu_data's col rank. "
                             << "x is " << x_shape[x_shape.size() - kDimNum] << ", lu_data is "
                             << lu_data_shape[lu_data_shape.size() - kDimNum] << ".";
  }
  // Compare the batch dimensions of x and lu_data right-aligned over the dimensions both have.
  // The row dimension has already been checked above.
  const size_t x_batch_rank = x_shape.size() - kDimNum;
  const size_t lu_batch_rank = lu_data_shape.size() - kDimNum;
  const size_t common_rank = x_batch_rank < lu_batch_rank ? x_batch_rank : lu_batch_rank;
  const size_t x_offset = x_batch_rank - common_rank;
  const size_t lu_offset = lu_batch_rank - common_rank;
  for (size_t i = 0; i < common_rank; i++) {
    if (x_shape[x_offset + i] != lu_data_shape[lu_offset + i]) {
      LuSolve_buffer(x_shape, lu_data_shape);
      MS_EXCEPTION(ValueError) << buffer.str();
    }
  }
  if (lu_pivots_shape[lu_pivots_shape.size() - 1] != lu_data_shape[lu_data_shape.size() - 1]) {
    MS_EXCEPTION(ValueError) << "For " << op_name
                             << " the last dimension of lu_pivots must be equal to the last dimension of lu_data, "
                             << "lu_data is " << lu_data_shape[lu_data_shape.size() - 1] << ", lu_pivots is "
                             << lu_pivots_shape[lu_pivots_shape.size() - 1] << ".";
  }
  // lu_pivots must have exactly one dimension fewer than lu_data; checking the rank first keeps the
  // element-wise comparison inside the bounds of lu_data_shape.
  bool pivots_match = (lu_pivots_shape.size() + 1 == lu_data_shape.size());
  for (size_t i = 0; pivots_match && i < lu_pivots_shape.size(); i++) {
    pivots_match = (lu_data_shape[i] == lu_pivots_shape[i]);
  }
  if (!pivots_match) {
    // Report the batch dimensions of lu_data itself (its shape without the two matrix dimensions).
    auto lu_data_batch = lu_data_shape;
    lu_data_batch.pop_back();
    lu_data_batch.pop_back();
    lu_pivots_shape.pop_back();
    buffer << "For " << op_name
           << " lu_data's batch dimension does not match lu_pivots's batch dimension, lu_data's batch dimension is [";
    LuSolve_for(lu_data_batch);
    buffer << "], lu_pivots's batch dimension is [";
    LuSolve_for(lu_pivots_shape);
    buffer << "], the size of the dimension and the number of each dimension must be the same.";
    MS_EXCEPTION(ValueError) << buffer.str();
  }
  if (x_shape.size() >= lu_data_shape.size()) {
    return std::make_shared<abstract::Shape>(x_shape);
  }
  // lu_data has more batch dimensions than x: the output takes them, with x's column count.
  auto dim_vector = lu_data_shape;
  dim_vector[lu_data_shape.size() - 1] = x_shape[x_shape.size() - 1];
  return std::make_shared<abstract::Shape>(dim_vector);
}
|
||||
|
||||
/**
 * Validates the input dtypes of LuSolve and returns the output dtype.
 *
 * x and lu_data must be tensors of the same type, float32 or float16; lu_pivots must be an int32
 * tensor. The returned type is the result of the x / lu_data same-type check.
 */
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  const int64_t kPivotsIndex = 2;
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const auto x_type = input_args[0]->BuildType();
  const auto lu_data_type = input_args[1]->BuildType();
  std::map<std::string, TypePtr> float_inputs;
  (void)float_inputs.emplace("x", x_type);
  (void)float_inputs.emplace("lu_data", lu_data_type);

  const std::set<TypePtr> valid_types = {kFloat32, kFloat16};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
  (void)CheckAndConvertUtils::CheckTensorTypeValid("lu_data", lu_data_type, valid_types, prim->name());
  auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(float_inputs, valid_types, prim->name());

  const std::set<TypePtr> valid_lu_pivots_types = {kInt32};
  const auto pivots_type = input_args[kPivotsIndex]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("lu_pivots", pivots_type, valid_lu_pivots_types, prim->name());
  return out_type;
}
|
||||
} // namespace
|
||||
|
||||
/**
 * Inference entry point for LuSolve: checks that exactly three inputs are given, then infers the
 * output type and shape (in that order) and combines them into the output abstract.
 */
AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kExpectedInputNum = 3;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kExpectedInputNum, primitive->name());
  // Type inference runs first; it also null-checks every input before the shapes are read.
  const auto out_type = InferType(primitive, input_args);
  const auto out_shape = InferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
|
||||
|
||||
// Registers LuSolveInfer as the evaluation implementation for prim::kPrimLuSolve.
// NOTE(review): the meaning of the trailing `nullptr, true` arguments is defined by the macro, which
// is not visible here — confirm against REGISTER_PRIMITIVE_EVAL_IMPL.
REGISTER_PRIMITIVE_EVAL_IMPL(LuSolve, prim::kPrimLuSolve, LuSolveInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_LUSOLVE_H_
|
||||
#define MINDSPORE_CORE_OPS_LUSOLVE_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameLuSolve = "LuSolve";
|
||||
class LuSolve : public PrimitiveC {
|
||||
public:
|
||||
LuSolve() : PrimitiveC(kNameLuSolve) { InitIOName({"x", "lu_data", "lu_pivots"}, {"output"}); }
|
||||
~LuSolve() = default;
|
||||
MS_DECLARE_PARENT(LuSolve, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimLuSolvePtr = std::shared_ptr<LuSolve>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_LUSOLVE_H_
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
"""aicpu ops"""
|
||||
from .unique import _unique_aicpu
|
||||
from .lu_solve import _lu_solve_aicpu
|
||||
from .no_repeat_ngram import _no_repeat_ngram_aicpu
|
||||
from .init_data_set_queue import _init_data_set_queue_aicpu
|
||||
from .embedding_lookup import _embedding_lookup_aicpu
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LuSolve op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
# AiCPU operator description for LuSolve: three required inputs, one required output, and the
# supported dtype combinations (float32 or float16 data with int32 pivots).
lu_solve_op_info = (
    AiCPURegOp("LuSolve")
    .fusion_type("OPAQUE")
    .input(0, "x", "required")
    .input(1, "lu_data", "required")
    .input(2, "lu_pivots", "required")
    .output(0, "output", "required")
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default)
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default)
    .get_op_info()
)
|
||||
|
||||
@op_info_register(lu_solve_op_info)
def _lu_solve_aicpu():
    """Register the LuSolve AiCPU operator info; the function body itself does nothing."""
    return None
|
|
@ -62,7 +62,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
|||
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
||||
Sin, Sqrt, Rsqrt, BesselI0, BesselI1, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Addcdiv,
|
||||
Addcmul, Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps,
|
||||
Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose)
|
||||
Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose, LuSolve)
|
||||
|
||||
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
|
||||
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
|
||||
|
@ -510,6 +510,7 @@ __all__ = [
|
|||
"NeighborExchange",
|
||||
"AlltoAll",
|
||||
"Custom",
|
||||
"LuSolve",
|
||||
]
|
||||
|
||||
__sponge__ = [
|
||||
|
|
|
@ -5818,6 +5818,7 @@ class Trunc(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize Trunc"""
|
||||
|
||||
|
||||
class IsClose(Primitive):
|
||||
r"""
|
||||
Returns a boolean tensor where two tensors are element-wise equal within a tolerance.
|
||||
|
@ -5874,3 +5875,55 @@ class IsClose(Primitive):
|
|||
raise ValueError("For IsClose, the `equal_nan` must be True, but got False.")
|
||||
validator.check_non_negative_float(rtol, 'rtol', self.name)
|
||||
validator.check_non_negative_float(atol, 'atol', self.name)
|
||||
|
||||
|
||||
class LuSolve(Primitive):
    """
    Return the solution of the linear equation Ax = b, where A is given in LU-factorized form
    (`lu_data` and `lu_pivots`) and b is the input `x`.

    Note:
        The batch dimensions of `lu_pivots` must match the batch dimensions of `lu_data`: both the number
        of batch dimensions and the size of each one must be the same. For example, if `lu_data` is
        (3, 3, 2, 2) and `lu_pivots` is (3, 3, 2), the batch dimensions of `lu_data` are (3, 3) and the
        batch dimensions of `lu_pivots` are (3, 3).

        The batch dimensions of `lu_data` must match the batch dimensions of `x`. The two may have a
        different number of batch dimensions; compared from right to left, corresponding dimensions must
        be equal. For example, if `lu_data` is (3, 3, 2, 2) and `x` is (2, 3, 3, 2, 1), the batch
        dimensions of `lu_data` are (3, 3) and the batch dimensions of `x` are (2, 3, 3).

    Inputs:
        - **x** (Tensor) - A tensor of shape (*, m, k), where * is the batch dimensions, with data type
          float32 or float16.
        - **lu_data** (Tensor) - A tensor of shape (*, m, m), where * is the batch dimensions, that can
          be decomposed into an upper triangular matrix U and a lower triangular matrix L, with data type
          float32 or float16.
        - **lu_pivots** (Tensor) - A tensor of shape (*, m), where * is the batch dimensions, that can
          be converted to a permutation matrix P, with data type int32.

    Outputs:
        Tensor, with the same data type as `x` and `lu_data`.

    Raises:
        TypeError: If dtype of `x` or `lu_data` is not one of: float32, float16.
        TypeError: If dtype of `lu_pivots` is not: int32.
        TypeError: If `x`, `lu_data` or `lu_pivots` is not Tensor.
        TypeError: If dtype of `x` is not same as dtype of `lu_data`.
        ValueError: If the batch dimensions of `lu_pivots` do not match the batch dimensions of `lu_data`.
        ValueError: If the batch dimensions of `x` do not match the batch dimensions of `lu_data`.
        ValueError: If `lu_data` is not a (batch of) square matrix.
        ValueError: If `x` dimension less than 2, `lu_data` dimension less than 2 or `lu_pivots` dimension
            less than 1.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1], [3], [3]]), mindspore.float32)
        >>> lu_data = Tensor(np.array([[2, 1, 1], [0.5, 1, 1.5], [0.5, 0, 2.5]]), mindspore.float32)
        >>> lu_pivots = Tensor(np.array([2, 2, 3]), mindspore.int32)
        >>> net = ops.LuSolve()
        >>> y = net(x, lu_data, lu_pivots)
        >>> print(y)
        [[ 1.9000002]
         [-1.4000001]
         [ 0.6      ]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize LuSolve"""
|
||||
|
|
|
@ -1543,6 +1543,12 @@ test_case_math_ops = [
|
|||
'block': P.Square(),
|
||||
'desc_inputs': [[4]],
|
||||
'desc_bprop': [[4]]}),
|
||||
('LuSolve', {
|
||||
'block': P.LuSolve(),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 3), mstype.float32),
|
||||
Tensor(np.random.rand(3, 3), mstype.float32),
|
||||
Tensor(np.random.rand(3), mstype.int32)],
|
||||
'skip': ['backward']}),
|
||||
('Rsqrt', {
|
||||
'block': P.Rsqrt(),
|
||||
'desc_inputs': [[4]],
|
||||
|
|
Loading…
Reference in New Issue