!44860 kernelmod: lu

Merge pull request !44860 from hujiahui8/kernelmod
This commit is contained in:
i-robot 2022-11-03 02:13:58 +00:00 committed by Gitee
commit 1f35feb33a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 252 additions and 88 deletions

View File

@ -63,39 +63,53 @@ void LUCpuKernelMod::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *
}
}
void LUCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
bool LUCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->GetPrim()->name();
dtype_ = inputs[0]->GetDtype();
size_t input_num = inputs.size();
CHECK_KERNEL_INPUTS_NUM(input_num, kLUInputsNum, kernel_name_);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = outputs.size();
CHECK_KERNEL_OUTPUTS_NUM(output_num, kLUOutputsNum, kernel_name_);
auto a_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kLUaIndex));
InitMatrixInfo(a_shape, &a_row_, &a_col_);
auto lu_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetOutputInferShape(kernel_node, kLuIndex));
InitMatrixInfo(lu_shape, &lu_row_, &lu_col_);
auto permutation_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetOutputInferShape(kernel_node, kPermutationIndex));
InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_);
auto pivots_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex));
InitPivotVecInfo(pivots_shape, &pivots_row_, &pivots_col_);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "LU does not support this kernel data type: " << kernel_attr;
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = std::get<1>(func_list_[index]);
const size_t kTwoIdx = 2;
init_io_func_ = std::get<kTwoIdx>(func_list_[index]);
kernel_func_ = func_list_[index].second;
return true;
}
template <typename T>
void LUCpuKernelMod::InitIOSize(const CNodePtr &kernel_node) {
DeprecatedNativeCpuKernelMod::InitInputOutputSize(kernel_node);
size_t lu_size = lu_col_ * sizeof(T);
int LUCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
a_row_ = 1;
a_col_ = 1;
lu_row_ = 1;
lu_col_ = 1;
permutation_row_ = 1;
permutation_col_ = 1;
pivots_row_ = 1;
pivots_col_ = 1;
auto a_shape = Convert2SizeTClipNeg(inputs[kLUaIndex]->GetShapeVector());
InitMatrixInfo(a_shape, &a_row_, &a_col_);
auto lu_shape = Convert2SizeTClipNeg(outputs[kLuIndex]->GetShapeVector());
InitMatrixInfo(lu_shape, &lu_row_, &lu_col_);
auto permutation_shape = Convert2SizeTClipNeg(outputs[kPermutationIndex]->GetShapeVector());
InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_);
auto pivots_shape = Convert2SizeTClipNeg(outputs[kPivotsIndex]->GetShapeVector());
InitPivotVecInfo(pivots_shape, &pivots_row_, &pivots_col_);
size_t lu_size = lu_col_ * dtype_;
(void)workspace_size_list_.emplace_back(lu_size);
(void)workspace_size_list_.emplace_back(lu_size);
return KRET_OK;
}
template <typename T>
@ -247,25 +261,24 @@ bool LUCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
return true;
}
std::vector<std::tuple<KernelAttr, LUCpuKernelMod::LUFunc, LUCpuKernelMod::InitFunc>> LUCpuKernelMod::func_list_ = {
std::vector<std::pair<KernelAttr, LUCpuKernelMod::LUFunc>> LUCpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&LUCpuKernelMod::LaunchKernel<float>, &LUCpuKernelMod::InitIOSize<float>},
&LUCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&LUCpuKernelMod::LaunchKernel<double>, &LUCpuKernelMod::InitIOSize<double>}};
&LUCpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> LUCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::tuple<KernelAttr, LUFunc, InitFunc> &tuple_item) { return std::get<0>(tuple_item); });
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, LUFunc> &pair) { return pair.first; });
return support_list;
}

View File

@ -14,9 +14,11 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_LU_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_LU_CPU_KERNEL_H_
#include <map>
#include <utility>
#include <tuple>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
@ -24,11 +26,16 @@
namespace mindspore {
namespace kernel {
class LUCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class LUCpuKernelMod : public NativeCpuKernelMod {
public:
LUCpuKernelMod() = default;
~LUCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
@ -51,12 +58,8 @@ class LUCpuKernelMod : public DeprecatedNativeCpuKernelMod {
const std::vector<kernel::AddressPtr> &outputs);
using LUFunc = std::function<bool(LUCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
using InitFunc = std::function<void(LUCpuKernelMod *, const CNodePtr &)>;
static std::vector<std::tuple<KernelAttr, LUFunc, InitFunc>> func_list_;
static std::vector<std::pair<KernelAttr, LUFunc>> func_list_;
LUFunc kernel_func_;
InitFunc init_io_func_;
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) const;
@ -75,4 +78,4 @@ class LUCpuKernelMod : public DeprecatedNativeCpuKernelMod {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_LU_CPU_KERNEL_H_

View File

@ -14,11 +14,12 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LU_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <map>
#include <string>
#include <algorithm>
#include <type_traits>
@ -32,7 +33,7 @@ namespace mindspore {
namespace kernel {
template <typename T>
class LUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class LUGpuKernelMod : public NativeGpuKernelMod {
public:
LUGpuKernelMod() : is_null_input_(false) {}
~LUGpuKernelMod() = default;
@ -59,26 +60,26 @@ class LUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
size_t host_transpose_shape[shape_2d] = {m_, n_};
size_t host_transpose_axis[shape_2d] = {1, 0};
T *dev_transpose_work = GetDeviceAddress<T>(workspace, kDim3);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dev_transpose_axis, host_transpose_axis, shape_2d * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(dev_transpose_axis, host_transpose_axis, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_,
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch.");
// 4. query working space of getrf
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
@ -88,24 +89,23 @@ class LUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
T *output_addr = batch_output_addr + batch * m_ * n_;
int *permutation_addr = batch_permutation_addr + batch * k_ * k_;
int *piv_output_addr = batch_piv_output_addr + batch * k_;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dev_transpose_shape, host_transpose_shape, shape_2d * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(dev_transpose_shape, host_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CalTranspose(m_ * n_, output_addr, dev_transpose_shape, dev_transpose_axis, shape_2d, dev_transpose_work,
reinterpret_cast<cudaStream_t>(stream_ptr));
// 6.lu factorization according to cuSolver api, outputs have been written to input's matrix.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_,
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, double>) {
// 6.lu factorization according to cuSolver api, outputs have been written to input's matrix.
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_,
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else {
@ -120,10 +120,10 @@ class LUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
std::vector<int> host_permuted(k_, 0);
std::vector<int> host_pivots(k_, 0);
std::vector<int> host_permutation(k_ * k_, 0);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_,
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots to host.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots to host.");
// cal pivots && permutation major by row.
for (size_t i = 0; i < k_; ++i) {
@ -139,40 +139,64 @@ class LUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
for (size_t i = 0; i < k_; ++i) {
host_permutation[host_permuted[i] * k_ + i] = 1;
}
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy permutation matrix.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots array.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy permutation matrix.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots array.");
}
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_);
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
// 1. get CuSolver Dense matrix handler
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
auto shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (IsDynamic(shape_signed)) {
return true;
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
batch_size_ = 1;
auto shape_signed = inputs[kIndex0]->GetShapeVector();
auto in_shape = Convert2SizeT(shape_signed);
// 2. check input shape not null
is_null_input_ = CHECK_SHAPE_NULL(in_shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return true;
return KRET_OK;
}
// 3. calculate input size
if (!InitInputSize(in_shape)) {
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed.";
MS_LOG(ERROR) << "For 'PureCholeskyGpuKernel', input shape init failed.";
return KRET_RESIZE_FAILED;
}
return true;
return KRET_OK;
}
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
};
return support_list;
}
private:
@ -198,7 +222,7 @@ class LUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
return true;
}
void InitSizeLists() override {
void InitSizeLists() {
size_t input_size = batch_size_ * lu_row_ * lu_col_ * unit_size_;
input_size_list_.push_back(input_size);
@ -244,4 +268,4 @@ class LUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LU_GPU_KERNEL_H_

82
mindspore/core/ops/lu.cc Normal file
View File

@ -0,0 +1,82 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/lu.h"
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kLUInputsNum = 1;
constexpr size_t kXDim = 2;
constexpr size_t kLastDim = 1;
constexpr size_t kPenultimateDim = 2;
abstract::TupleShapePtr LUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto x_shape = x_shape_map[kShape];
auto x_output = std::make_shared<abstract::Shape>(x_shape);
if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_output, x_output, x_output});
}
size_t x_shape_size = x_shape.size();
if (x_shape_size < kXDim) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "',"
<< " the dimension of hashmap must be greater than or equal to 2, but got: "
<< x_shape_size << ".";
}
auto k_shape = std::min(x_shape[x_shape_size - kLastDim], x_shape[x_shape_size - kPenultimateDim]);
ShapeVector top_k_shape(x_shape.begin(), x_shape.end() - kPenultimateDim);
ShapeVector pivots_shape = top_k_shape;
pivots_shape.push_back(k_shape);
ShapeVector permutation_shape = pivots_shape;
permutation_shape.push_back(k_shape);
auto pivots_output = std::make_shared<abstract::Shape>(pivots_shape);
auto permutation_output = std::make_shared<abstract::Shape>(permutation_shape);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{x_output, pivots_output, permutation_output});
}
TuplePtr LUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto x_type = input_args[0]->BuildType();
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, kInt32, kInt32});
}
} // namespace
AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kLUInputsNum, primitive->name());
auto infer_type = LUInferType(primitive, input_args);
auto infer_shape = LUInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(LU, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(LU, prim::kPrimLU, LUInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

42
mindspore/core/ops/lu.h Normal file
View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
#ifndef MINDSPORE_CORE_OPS_LU_H_
#define MINDSPORE_CORE_OPS_LU_H_
#include <map>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLU = "LU";
class MIND_API LU : public BaseOperator {
public:
MIND_API_BASE_MEMBER(LU);
LU() : BaseOperator(kNameLU) { InitIOName({"x"}, {"lu", "pivots", "permutation"}); }
};
abstract::AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimLUPtr = std::shared_ptr<LU>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LU_H_