Add dynamic-shape support for SolveTriangular

This commit is contained in:
zhujingxuan 2022-10-12 16:18:00 +08:00
parent e24f136909
commit 0bdfd3e052
12 changed files with 671 additions and 513 deletions

View File

@ -1,166 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h"
#include <Eigen/Dense>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
namespace mindspore {
namespace kernel {
using Eigen::ColMajor;
using Eigen::Dynamic;
using Eigen::Lower;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::UnitLower;
using Eigen::UnitUpper;
using Eigen::Upper;
// Dynamically sized Eigen matrix of element type T; Major selects the storage order (RowMajor or ColMajor).
template <typename T, int Major>
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
// Input count verified at the top of LaunchKernel (inputs[0] is a, inputs[1] is b).
constexpr auto kSolveTriangularInputsNum = 2;
// Output count verified at the top of LaunchKernel (outputs[0] receives the solution).
constexpr auto kSolveTriangularOutputsNum = 1;
// NOTE(review): kAVectorxDimNum, kAMatrixDimNum and kColIndex are not referenced in the visible part of this file.
constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
// Offset from the end of a's shape to its row dimension: InitShape reads a_shape[a_dims - kRowIndex] as m and
// treats every dimension before that index as a batch dimension.
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
void MatrixTriangularSolveCpuKernelMod::InitShape(const CNodePtr &kernel_node) {
auto a_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0));
auto b_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1));
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
size_t a_dims = a_shape.size();
size_t aRowIndex = a_dims - kRowIndex;
m_ = a_shape[aRowIndex];
size_t b_sims = b_shape.size();
bool vector_b = b_sims == a_dims - 1;
if (vector_b) {
n_ = 1;
} else {
n_ = b_shape[b_sims - 1];
}
batch_ = 1;
for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) {
batch_ *= a_shape[batch];
}
}
// Reads the node's name, shapes and attributes, then selects the typed launch function.
// Two attribute layouts are accepted:
//   * with ADJOINT: only that boolean is read (into trans_); a simultaneous TRANS attribute is rejected;
//   * without ADJOINT: LOWER, UNIT_DIAGONAL and the TRANS string ("N", "T" or "C") are read.
// Raises through MS_LOG(EXCEPTION) on conflicting attributes, an unknown TRANS value or an unsupported dtype.
void MatrixTriangularSolveCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  InitShape(kernel_node);

  const bool has_adjoint = common::AnfAlgo::HasNodeAttr(ADJOINT, kernel_node);
  if (!has_adjoint) {
    lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
    unit_diagonal_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, UNIT_DIAGONAL);
    const std::string trans_mode = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, TRANS);
    if (trans_mode == "T" || trans_mode == "C") {
      // "C" is handled exactly like "T" here.
      trans_ = true;
    } else if (trans_mode == "N") {
      trans_ = false;
    } else {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' must be in ['N', 'T', 'C'], but got [" << trans_mode
                        << "].";
    }
  } else {
    // MatrixTriangularSolve attribute
    trans_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJOINT);
    if (common::AnfAlgo::HasNodeAttr(TRANS, kernel_node)) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the attribute 'adjoint' and 'trans' could not exist at the same time.";
    }
  }

  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
  const auto [matched, func_index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!matched) {
    MS_LOG(EXCEPTION) << "SolveTriangular does not support this kernel data type: " << kernel_attr;
  }
  kernel_func_ = func_list_[func_index].second;
}
template <typename Derived_a, typename Derived_b, typename T>
inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b, T *output_addr, int m, int n,
bool lower, bool unit_diagonal) {
Map<Matrix<T, RowMajor>> output(output_addr, m, n);
if (unit_diagonal) {
if (lower) {
output.noalias() = a.template triangularView<UnitLower>().solve(b);
} else {
output.noalias() = a.template triangularView<UnitUpper>().solve(b);
}
} else {
if (lower) {
output.noalias() = a.template triangularView<Lower>().solve(b);
} else {
output.noalias() = a.template triangularView<Upper>().solve(b);
}
}
}
template <typename T>
bool MatrixTriangularSolveCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_);
auto a_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto b_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t a_batch_size = m_ * m_;
size_t b_batch_size = m_ * n_;
size_t output_batch_size = m_ * n_;
for (size_t i = 0; i < batch_; ++i) {
T *a_batch_addr = a_addr + i * a_batch_size;
T *b_batch_addr = b_addr + i * b_batch_size;
T *output_batch_addr = output_addr + i * output_batch_size;
Map<Matrix<T, RowMajor>> b(b_batch_addr, m_, n_);
if (trans_) {
Map<Matrix<T, ColMajor>> a(a_batch_addr, m_, m_);
solve(a, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_);
} else {
Map<Matrix<T, RowMajor>> a(a_batch_addr, m_, m_);
solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_);
}
}
return true;
}
// Dispatch table: each entry pairs a supported (input a, input b, output) dtype signature with the LaunchKernel
// instantiation that handles it. InitKernel picks an entry by the index returned from MatchKernelAttr, and
// GetOpSupport exposes the KernelAttr half of every entry.
std::vector<std::pair<KernelAttr, MatrixTriangularSolveCpuKernelMod::MatrixTriangularSolveFunc>>
MatrixTriangularSolveCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MatrixTriangularSolveCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MatrixTriangularSolveFunc> &pair) { return pair.first; });
return support_list;
}
// Register the same kernel mod under both operator names; InitKernel distinguishes the two attribute layouts
// by checking for the ADJOINT attribute.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SolveTriangular, MatrixTriangularSolveCpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixTriangularSolve, MatrixTriangularSolveCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,151 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h"
#include <Eigen/Dense>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include <map>
#include "mindspore/core/ops/solve_triangular.h"
namespace mindspore {
namespace kernel {
using Eigen::ColMajor;
using Eigen::Dynamic;
using Eigen::Lower;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::UnitLower;
using Eigen::UnitUpper;
using Eigen::Upper;
// Dynamically sized Eigen matrix of element type T; Major selects the storage order (RowMajor or ColMajor).
template <typename T, int Major>
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
// Shorthand for the launch-function type used in the GetFuncList dispatch table.
using KernelRunFunc = SolveTriangularCpuKernelMod::KernelRunFunc;
// Input count verified at the top of LaunchKernel (inputs[0] is a, inputs[1] is b).
constexpr auto kSolveTriangularInputsNum = 2;
// Output count verified at the top of LaunchKernel (outputs[0] receives the solution).
constexpr auto kSolveTriangularOutputsNum = 1;
int SolveTriangularCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto a_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
auto b_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
size_t a_dims = a_shape.size();
size_t b_dims = b_shape.size();
m_ = a_shape[a_dims - kIndex2];
n_ = (b_dims == a_dims - 1) ? 1 : b_shape[b_dims - 1];
batch_ = std::accumulate(a_shape.begin(), a_shape.end() - kIndex2, int64_t(1), std::multiplies{});
return KRET_OK;
}
bool SolveTriangularCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_ptr = std::make_shared<ops::SolveTriangular>(base_operator->GetPrim());
lower_ = kernel_ptr->get_lower();
unit_diagonal_ = kernel_ptr->get_unit_diagonal();
const std::string trans = kernel_ptr->get_trans();
if (trans == "N") {
trans_ = false;
} else if (trans == "T") {
trans_ = true;
} else if (trans == "C") {
// currently does not support complex.
trans_ = true;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' must be in ['N', 'T', 'C'], but got [" << trans << "].";
}
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
template <typename Derived_a, typename Derived_b, typename T>
inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b, T *output_addr, int m, int n,
bool lower, bool unit_diagonal) {
Map<Matrix<T, RowMajor>> output(output_addr, m, n);
if (unit_diagonal) {
if (lower) {
output.noalias() = a.template triangularView<UnitLower>().solve(b);
} else {
output.noalias() = a.template triangularView<UnitUpper>().solve(b);
}
} else {
if (lower) {
output.noalias() = a.template triangularView<Lower>().solve(b);
} else {
output.noalias() = a.template triangularView<Upper>().solve(b);
}
}
}
template <typename T>
bool SolveTriangularCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_);
auto a_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto b_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t a_batch_size = m_ * m_;
size_t b_batch_size = m_ * n_;
size_t output_batch_size = m_ * n_;
for (size_t i = 0; i < batch_; ++i) {
T *a_batch_addr = a_addr + i * a_batch_size;
T *b_batch_addr = b_addr + i * b_batch_size;
T *output_batch_addr = output_addr + i * output_batch_size;
Map<Matrix<T, RowMajor>> b(b_batch_addr, m_, n_);
if (trans_) {
Map<Matrix<T, ColMajor>> a(a_batch_addr, m_, m_);
solve(a, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_);
} else {
Map<Matrix<T, RowMajor>> a(a_batch_addr, m_, m_);
solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_);
}
}
return true;
}
// Returns the dispatch table pairing each supported (input a, input b, output) dtype signature with the
// LaunchKernel instantiation that handles it. The table is a function-local static, built once on first call.
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SolveTriangularCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SolveTriangularCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&SolveTriangularCpuKernelMod::LaunchKernel<double>},
};
return func_list;
}
// Register this kernel mod with the CPU kernel factory under the operator name SolveTriangular.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SolveTriangular, SolveTriangularCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -19,35 +19,36 @@
#include <vector>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixTriangularSolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SolveTriangularCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<SolveTriangularCpuKernelMod> {
public:
MatrixTriangularSolveCpuKernelMod() = default;
~MatrixTriangularSolveCpuKernelMod() override = default;
SolveTriangularCpuKernelMod() = default;
~SolveTriangularCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
std::vector<KernelAttr> GetOpSupport() override;
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void InitShape(const CNodePtr &kernel_node);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using MatrixTriangularSolveFunc =
std::function<bool(MatrixTriangularSolveCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MatrixTriangularSolveFunc>> func_list_;
MatrixTriangularSolveFunc kernel_func_;
size_t m_{0};
size_t n_{0};

View File

@ -1,30 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Register the float32 instantiation of the GPU kernel for SolveTriangular (both inputs and the output float32).
MS_REG_GPU_KERNEL_ONE(
SolveTriangular,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MatrixTriangularSolveGpuKernelMod, float)
// Register the float64 instantiation of the GPU kernel for SolveTriangular (both inputs and the output float64).
MS_REG_GPU_KERNEL_ONE(
SolveTriangular,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MatrixTriangularSolveGpuKernelMod, double)
} // namespace kernel
} // namespace mindspore

View File

@ -1,259 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <type_traits>
#include <vector>
#include <string>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/gpu_kernel_utils.h"
namespace mindspore {
namespace kernel {
// NOTE(review): kAVectorxDimNum, kAMatrixDimNum and kColIndex are not referenced in the visible part of this header.
constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
// Offset from the end of a's shape to its row dimension: InitShape reads a_shape[a_dims - kRowIndex] as m.
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
// Element count of the transpose shape/axis workspaces (the transposes operate on {batch, rows, cols}).
constexpr size_t kShape3D = 3;
// Positions inside the workspace list built by InitSizeLists. Entries 2..4 exist only when n_ != 1.
// Device array of per-batch pointers into a.
constexpr size_t kIndexAArray = 0;
// Device array of per-batch pointers into the buffer being solved in place.
constexpr size_t kIndexDstArray = 1;
// Scratch buffer holding the transposed copy of b.
constexpr size_t kIndexBBuffer = 2;
// Device buffer for the shape argument of MatrixTransposeND.
constexpr size_t kIndexBTransposeShape = 3;
// Device buffer for the axis argument of MatrixTransposeND.
constexpr size_t kIndexBTransposeAxis = 4;
// GPU kernel that solves a batch of triangular systems with cuBLAS batched trsm
// (cublasStrsmBatched for float, cublasDtrsmBatched otherwise), using CUBLAS_SIDE_LEFT and alpha = 1.
// The input buffers are treated as row-major while the solve is set up for column-major data, so Init inverts
// the transpose and fill-mode flags and Launch transposes b before and the result after the solve when n_ != 1.
template <typename T>
class MatrixTriangularSolveGpuKernelMod : public DeprecatedNativeGpuKernelMod {
public:
MatrixTriangularSolveGpuKernelMod() = default;
~MatrixTriangularSolveGpuKernelMod() = default;
// Runs the batched solve on the stream passed as stream_ptr.
// inputs[0] is a, inputs[1] is b, outputs[0] receives x; the workspace layout is the one built by
// InitSizeLists. Returns true immediately when an input shape was flagged as null by InitShape.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
auto inputa_addr = GetDeviceAddress<T>(inputs, 0);
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
auto output_addr = GetDeviceAddress<T>(outputs, 0);
// if b is not a vector, solve b in the workspace
// (trsm overwrites its right-hand side, so dst is both the staged copy of b and the place the solution lands)
T *dst = nullptr;
if (n_ == 1) {
dst = output_addr;
} else {
dst = GetDeviceAddress<T>(workspace, kIndexBBuffer);
}
const size_t batched_b_size = batch_ * m_ * n_;
if (n_ == 1) {
// Vector case: a plain copy of b into the output is enough, no transpose is performed.
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dst, inputb_addr, batched_b_size * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync dst failed");
} else {
// No matter how many batch dimensions the batched matrix b has, use their cumulative multiplication batch.
// In order to convert row major matrix b(batch, m, n) to col major matrix b'(batch, m, n),
// the following operation is equivalent to:
// b' = b.transpose(batch, n, m).reshape(batch, m, n)
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
MatrixTransposeND(inputb_addr, {batch_, m_, n_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape,
dev_transpose_b_axis, dst, reinterpret_cast<cudaStream_t>(stream_ptr), kernel_name_);
}
// index calculation
// Build the per-batch pointer tables on the host, then copy them into the device workspaces.
auto device_a_array_addr = GetDeviceAddress<T *>(workspace, kIndexAArray);
auto device_dst_array_addr = GetDeviceAddress<T *>(workspace, kIndexDstArray);
for (size_t i = 0; i < batch_; i++) {
host_a_array_[i] = inputa_addr + i * m_ * m_;
host_dst_array_[i] = dst + i * m_ * n_;
}
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(device_a_array_addr, host_a_array_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(device_dst_array_addr, host_dst_array_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
T alpha = 1;
// Only float and double are registered for this kernel, so the else branch is the double path.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
"cublas trsm Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasDtrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
"cublas trsm Fail");
}
// if x is not a vector, do transpose
if (n_ != 1) {
// in order to convert col major matrix x'(m x n) to row major matrix x'(m x n),
// the following operation is equivalent to:
// x = x'.reshape(n, m).T
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
MatrixTransposeND(dst, {batch_, n_, m_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape, dev_transpose_b_axis,
output_addr, reinterpret_cast<cudaStream_t>(stream_ptr), kernel_name_);
}
return true;
}
// Reads shapes and attributes from the node and prepares the cuBLAS flags and size lists.
// Returns true in every path; invalid attribute combinations raise through MS_LOG(EXCEPTION).
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
InitShape(kernel_node);
// With dynamic shapes nothing else is set up here: attributes and size lists are left untouched.
if (is_dynamic_) {
return true;
}
if (is_null_input_) {
InitSizeLists();
return true;
}
lda_ = SizeToInt(m_);
ldb_ = SizeToInt(m_);
if (common::AnfAlgo::HasNodeAttr("adjoint", kernel_node)) {
// MatrixTriangularSolve attribute
// NOTE(review): this branch sets only trans_; uplo_ and unit_diagonal_ keep their default member values.
bool trans = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "adjoint");
// converting row major to col major is the same as reverting the trans flag
trans_ = trans ? CUBLAS_OP_N : CUBLAS_OP_T;
if (common::AnfAlgo::HasNodeAttr("trans", kernel_node)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the attribute 'adjoint' and 'trans' could not exist at the same time.";
}
} else {
bool lower = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");
// reverting the trans flag by default, so also flip the lower flag
lower = !lower;
uplo_ = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
bool unit_diagonal = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "unit_diagonal");
unit_diagonal_ = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
const std::string trans = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans");
// converting row major to col major is the same as reverting the trans flag
if (trans == "N") {
trans_ = CUBLAS_OP_T;
} else if (trans == "T") {
trans_ = CUBLAS_OP_N;
} else if (trans == "C") {
// "C" is mapped to the same flag as "T".
trans_ = CUBLAS_OP_N;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "].";
}
}
InitSizeLists();
return true;
}
protected:
// Fills the input, output and workspace byte-size lists from batch_, m_ and n_.
// The workspace order must match the kIndex* constants used by Launch.
void InitSizeLists() override {
size_t unit_size = sizeof(T);
size_t a_size = batch_ * m_ * m_ * unit_size;
size_t b_size = batch_ * m_ * n_ * unit_size;
input_size_list_ = {a_size, b_size};
output_size_list_ = {b_size};
if (n_ != 1) {
// NOTE(review): the two transpose buffers are read as size_t in Launch (GetDeviceAddress<size_t>) but are
// sized with sizeof(size_t *) below.
workspace_size_list_ = {
// workspace for batched a
batch_ * sizeof(T *),
// workspace for batched b
batch_ * sizeof(T *),
// workspace for transposed b
b_size,
// workspace for b transpose shape
kShape3D * sizeof(size_t *),
// workspace for b transpose axis
kShape3D * sizeof(size_t *),
};
} else {
workspace_size_list_ = {
// workspace for batched a
batch_ * sizeof(T *),
// workspace for batched b
batch_ * sizeof(T *),
};
}
}
// Derives m_, n_ and batch_ from the inferred input shapes and sizes the host pointer tables.
// Sets is_dynamic_ and returns early, leaving every other member untouched, when either shape is dynamic.
void InitShape(const CNodePtr &kernel_node) {
is_dynamic_ = false;
auto a_shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto b_shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
if (AnfAlgo::IsShapesDynamic({a_shape_signed, b_shape_signed})) {
is_dynamic_ = true;
return;
}
auto a_shape = Convert2SizeTClipNeg(a_shape_signed);
auto b_shape = Convert2SizeTClipNeg(b_shape_signed);
is_null_input_ =
CHECK_SHAPE_NULL(a_shape, kernel_name_, "input_a") || CHECK_SHAPE_NULL(b_shape, kernel_name_, "input_b");
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
size_t a_dims = a_shape.size();
size_t aRowIndex = a_dims - kRowIndex;
m_ = a_shape[aRowIndex];
size_t b_sims = b_shape.size();
// b is treated as a (batched) vector when it has exactly one dimension fewer than a.
bool vector_b = b_sims == a_dims - 1;
if (vector_b) {
n_ = 1;
} else {
n_ = b_shape[b_sims - 1];
}
// Collapse every dimension of a before the row dimension into a single batch count.
batch_ = 1;
for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) {
batch_ *= a_shape[batch];
}
host_a_array_.resize(batch_);
host_dst_array_.resize(batch_);
}
private:
// Row count of a (also the row count of b) and column count of b (1 for a vector b).
size_t m_{0};
size_t n_{0};
// Number of independent systems.
size_t batch_{1};
// Leading dimensions passed to trsm; Init sets both to m_.
int lda_{0};
int ldb_{0};
bool is_null_input_{false};
bool is_dynamic_{false};
// Host-side staging for the per-batch device pointers copied to the workspace in Launch.
std::vector<T *> host_a_array_;
std::vector<T *> host_dst_array_;
cublasHandle_t blas_handle_{nullptr};
// cuBLAS flags; Init stores them already inverted for the row-major/column-major switch.
cublasFillMode_t uplo_{CUBLAS_FILL_MODE_UPPER};
cublasOperation_t trans_{CUBLAS_OP_N};
cublasDiagType_t unit_diagonal_{CUBLAS_DIAG_NON_UNIT};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_

View File

@ -0,0 +1,199 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h"
#include <map>
#include <utility>
#include <memory>
#include "mindspore/core/ops/solve_triangular.h"
namespace mindspore {
namespace kernel {
// Shorthand for the launch-function type used in the GetFuncList dispatch table.
using KernelRunFunc = SolveTriangularGpuKernelMod::KernelRunFunc;
bool SolveTriangularGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto kernel_ptr = std::make_shared<ops::SolveTriangular>(base_operator->GetPrim());
bool lower = kernel_ptr->get_lower();
// reverting the trans flag by default, so also flip the lower flag
lower = !lower;
uplo_ = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
bool unit_diagonal = kernel_ptr->get_unit_diagonal();
unit_diagonal_ = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
const std::string trans = kernel_ptr->get_trans();
if (trans == "N") {
trans_ = CUBLAS_OP_T;
} else if (trans == "T") {
trans_ = CUBLAS_OP_N;
} else if (trans == "C") {
// currently does not support complex.
trans_ = CUBLAS_OP_N;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' must be in ['N', 'T', 'C'], but got [" << trans << "].";
}
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
int SolveTriangularGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto a_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
auto b_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
is_null_input_ =
CHECK_SHAPE_NULL(a_shape, kernel_name_, "input_a") || CHECK_SHAPE_NULL(b_shape, kernel_name_, "input_b");
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
size_t a_dims = a_shape.size();
size_t b_dims = b_shape.size();
m_ = a_shape[a_dims - kIndex2];
n_ = (b_dims == a_dims - 1) ? 1 : b_shape[b_dims - 1];
batch_ = std::accumulate(a_shape.begin(), a_shape.end() - kIndex2, int64_t(1), std::multiplies{});
lda_ = SizeToInt(m_);
ldb_ = SizeToInt(m_);
const size_t unit_size = GetTypeByte(TypeIdToType(inputs.at(kIndex0)->GetDtype()));
constexpr size_t pointer_size = sizeof(float *);
size_t b_size = batch_ * m_ * n_ * unit_size;
workspace_size_list_.clear();
if (n_ != 1) {
workspace_size_list_ = {
// workspace for batched a
batch_ * pointer_size,
// workspace for batched b
batch_ * pointer_size,
// workspace for transposed b
b_size,
// workspace for b transpose shape
kShape3D * sizeof(size_t *),
// workspace for b transpose axis
kShape3D * sizeof(size_t *),
};
} else {
workspace_size_list_ = {
// workspace for batched a
batch_ * pointer_size,
// workspace for batched b
batch_ * pointer_size,
};
}
return KRET_OK;
}
// Solves the batched triangular system on cuda_stream_ with cuBLAS batched trsm
// (cublasStrsmBatched for float, cublasDtrsmBatched otherwise), using CUBLAS_SIDE_LEFT and alpha = 1.
// inputs[0] is a, inputs[1] is b, outputs[0] receives x; the workspace layout is the one built by Resize.
// Always returns true; CUDA/cuBLAS failures are reported through the CHECK_* macros.
// NOTE(review): Resize sets is_null_input_ but this function does not read it — presumably the caller checks
// it before dispatching here; confirm against the header.
template <typename T>
bool SolveTriangularGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, cuda_stream_), "cublasSetStream failed");
auto inputa_addr = GetDeviceAddress<T>(inputs, 0);
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
auto output_addr = GetDeviceAddress<T>(outputs, 0);
// Host-side staging for the per-batch device pointers; allocated per launch because batch_ can change
// between Resize calls.
std::vector<T *> host_a_array(batch_);
std::vector<T *> host_dst_array(batch_);
// if b is not a vector, solve b in the workspace
// (trsm overwrites its right-hand side, so dst is both the staged copy of b and the place the solution lands)
T *dst = nullptr;
if (n_ == 1) {
dst = output_addr;
} else {
dst = GetDeviceAddress<T>(workspace, kIndexBBuffer);
}
const size_t batched_b_size = batch_ * m_ * n_;
if (n_ == 1) {
// Vector case: a plain copy of b into the output is enough, no transpose is performed.
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(dst, inputb_addr, batched_b_size * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream_),
"cudaMemcpyAsync dst failed");
} else {
// No matter how many batch dimensions the batched matrix b has, use their cumulative multiplication batch.
// In order to convert row major matrix b(batch, m, n) to col major matrix b'(batch, m, n),
// the following operation is equivalent to:
// b' = b.transpose(batch, n, m).reshape(batch, m, n)
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
MatrixTransposeND(inputb_addr, {batch_, m_, n_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape, dev_transpose_b_axis,
dst, cuda_stream_, kernel_name_);
}
// index calculation
// Build the per-batch pointer tables on the host, then copy them into the device workspaces.
auto device_a_array_addr = GetDeviceAddress<T *>(workspace, kIndexAArray);
auto device_dst_array_addr = GetDeviceAddress<T *>(workspace, kIndexDstArray);
for (size_t i = 0; i < batch_; i++) {
host_a_array[i] = inputa_addr + i * m_ * m_;
host_dst_array[i] = dst + i * m_ * n_;
}
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(device_a_array_addr, host_a_array.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, cuda_stream_),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(device_dst_array_addr, host_dst_array.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, cuda_stream_),
"cuda memcopy Fail");
T alpha = 1;
// Only float and double are listed in GetFuncList, so the else branch is the double path.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
"cublas trsm Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasDtrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
"cublas trsm Fail");
}
// if x is not a vector, do transpose
if (n_ != 1) {
// in order to convert col major matrix x'(m x n) to row major matrix x'(m x n),
// the following operation is equivalent to:
// x = x'.reshape(n, m).T
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
MatrixTransposeND(dst, {batch_, n_, m_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape, dev_transpose_b_axis,
output_addr, cuda_stream_, kernel_name_);
}
return true;
}
// Returns the dispatch table pairing each supported (input a, input b, output) dtype signature with the
// LaunchKernel instantiation that handles it. The table is a function-local static, built once on first call.
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SolveTriangularGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SolveTriangularGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&SolveTriangularGpuKernelMod::LaunchKernel<double>},
};
return func_list;
}
// Register this kernel mod with the GPU kernel factory under the operator name SolveTriangular.
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SolveTriangular, SolveTriangularGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,90 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <type_traits>
#include <vector>
#include <string>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/gpu_kernel_utils.h"
namespace mindspore {
namespace kernel {
// Constants shared by the SolveTriangular GPU kernel.
// NOTE(review): presumably the number of trailing dims of `b` when it is a vector / a matrix — confirm in Resize.
constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
// NOTE(review): presumably offsets from the end of a shape to its row / column dimension — confirm at use sites.
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
constexpr size_t kShape3D = 3;
// Slot indices into the kernel's workspace list. The transpose shape/axis slots are read with
// GetDeviceAddress<size_t>(workspace, ...) in LaunchKernel.
// NOTE(review): the first three presumably hold the A pointer array, the dst pointer array and the
// b buffer respectively — confirm against the workspace sizes set up in Resize.
constexpr size_t kIndexAArray = 0;
constexpr size_t kIndexDstArray = 1;
constexpr size_t kIndexBBuffer = 2;
constexpr size_t kIndexBTransposeShape = 3;
constexpr size_t kIndexBTransposeAxis = 4;
/// GPU kernel module for the SolveTriangular operator.
///
/// Init/Resize are defined out of line; Launch dispatches to the dtype-specific LaunchKernel
/// instantiation selected through MatchKernelHelper (see GetFuncList).
class SolveTriangularGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<SolveTriangularGpuKernelMod> {
 public:
  SolveTriangularGpuKernelMod() = default;
  ~SolveTriangularGpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  /// Runs the kernel on the given stream.
  /// When the null-input flag is set the launch is a no-op that reports success; otherwise the
  /// stream is remembered and the matched kernel function is invoked.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (!is_null_input_) {
      cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
      return kernel_func_(this, inputs, workspace, outputs);
    }
    return true;
  }

  /// Table of supported dtype signatures and their launch functions.
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

 private:
  /// Dtype-specific implementation; T is float or double (see GetFuncList).
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                    const std::vector<AddressPtr> &outputs);

  // Problem sizes passed to the batched cuBLAS trsm call as m, n and batch count.
  size_t m_{0};
  size_t n_{0};
  size_t batch_{1};
  // Leading dimensions passed to the batched cuBLAS trsm call for A and for the right-hand side.
  int lda_{0};
  int ldb_{0};
  // When true, Launch returns immediately without doing any work.
  bool is_null_input_{false};
  cublasHandle_t blas_handle_{nullptr};
  // cuBLAS options forwarded to trsm: fill mode, operation applied to A, and diagonal type.
  cublasFillMode_t uplo_{CUBLAS_FILL_MODE_UPPER};
  cublasOperation_t trans_{CUBLAS_OP_N};
  cublasDiagType_t unit_diagonal_{CUBLAS_DIAG_NON_UNIT};
  // Stream captured in Launch and used by LaunchKernel for async copies and kernels.
  cudaStream_t cuda_stream_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_

View File

@ -124,6 +124,7 @@ constexpr auto kKSize = "ksize";
constexpr auto kKsizes = "ksizes";
constexpr auto kKernelType = "kernel_type";
constexpr auto kLimit = "limit";
constexpr auto kLower = "lower";
constexpr auto kMagSquare = "mag_square";
constexpr auto kMax = "max";
constexpr auto kMaxSizes = "max_sizes";
@ -228,10 +229,12 @@ constexpr auto kSummarize = "summarize";
constexpr auto kTimeMajor = "time_major";
constexpr auto kTolerance = "tolerance";
constexpr auto kTopK = "top_k";
constexpr auto kTrans = "trans";
constexpr auto kTransposeA = "transpose_a";
constexpr auto kTransposeB = "transpose_b";
constexpr auto kNegativeSlope = "negative_slope";
constexpr auto kType = "type";
constexpr auto kUnitDiagonal = "unit_diagonal";
constexpr auto kUpdateSlots = "update_slots";
constexpr auto kUseAxis = "use_axis";
constexpr auto kUseLocking = "use_locking";

View File

@ -0,0 +1,134 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/solve_triangular.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of SolveTriangular(a, b), which is always the shape of `b`.
//
// For fully static shapes the following is validated (ValueError on violation):
//   * `a` has at least 2 dims and its last two dims are equal ([..., N, N]);
//   * `b` has either the same rank as `a` ([..., N, M]) or one less ([..., N]);
//   * the row count of `a` matches the corresponding dim of `b`;
//   * all leading (batch) dims of `a` and `b` are identical.
// When either input has unknown rank, a one-element shape holding kShapeDimAny is returned;
// when either input merely has dynamic dims, `b`'s shape is passed through without validation.
// NOTE(review): for the unknown-rank case, confirm kShapeDimAny (not a rank-any sentinel) is intended.
abstract::ShapePtr SolveTriangularInferShape(const PrimitivePtr &primitive,
                                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto op_name = primitive->name();

  auto a_base_shape = input_args[kInputIndex0]->BuildShape();
  MS_EXCEPTION_IF_NULL(a_base_shape);
  auto b_base_shape = input_args[kInputIndex1]->BuildShape();
  MS_EXCEPTION_IF_NULL(b_base_shape);

  // Dynamic inputs: nothing can be validated yet.
  if (a_base_shape->IsDimUnknown() || b_base_shape->IsDimUnknown()) {
    return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeDimAny});
  }
  if (a_base_shape->IsDynamic() || b_base_shape->IsDynamic()) {
    return b_base_shape->cast<abstract::ShapePtr>();
  }

  const auto a_dims = CheckAndConvertUtils::ConvertShapePtrToShapeMap(a_base_shape)[kShape];
  const auto b_dims = CheckAndConvertUtils::ConvertShapePtrToShapeMap(b_base_shape)[kShape];

  constexpr size_t kMatrixRank = 2;
  // `b` counts as a vector when it has exactly one dim fewer than `a`; then only its last dim is "core".
  const bool b_is_vector = (b_dims.size() == a_dims.size() - 1);
  const size_t b_core_rank = b_is_vector ? 1 : kMatrixRank;
  const size_t a_rank = a_dims.size();
  const size_t b_rank = b_dims.size();

  CheckAndConvertUtils::CheckValue<size_t>("dim of matrix a", a_rank, kGreaterEqual, kMatrixRank, op_name);
  CheckAndConvertUtils::CheckValue<size_t>("dim of matrix b", b_rank, kGreaterEqual, b_core_rank, op_name);

  const bool rank_compatible = (a_rank == b_rank) || (a_rank - 1 == b_rank);
  if (!rank_compatible) {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', the dimension of `b` should be 'a.dim' or 'a.dim' - 1, which is " << a_rank << " or "
                             << (a_rank - 1) << ", but got " << b_rank << " dimensions.";
  }

  const auto a_cols = a_dims[a_rank - 1];
  const auto a_rows = a_dims[a_rank - kMatrixRank];
  if (a_cols != a_rows) {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', the last two dimensions of `a` should be the same, but got shape of " << a_dims
                             << ". Please make sure that the shape of `a` be like [..., N, N].";
  }

  if (a_rows != b_dims[b_rank - b_core_rank]) {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', the last two dimensions of `a` and `b` should be matched, but got shape of "
                             << a_dims << " and " << b_dims
                             << ". Please make sure that the shape of `a` and `b` be like [..., N, N] X [..., N, M] or "
                                "[..., N, N] X [..., N].";
  }

  // Compare the batch prefixes (everything before the core dims) element by element.
  const auto a_batch_end = a_dims.begin() + (a_rank - kMatrixRank);
  const auto b_batch_end = b_dims.begin() + (b_rank - b_core_rank);
  const bool batch_matches = std::equal(a_dims.begin(), a_batch_end, b_dims.begin(), b_batch_end);
  if (!batch_matches) {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', the batch dimensions of `a` and `b` should all be the same, but got shape of "
                             << a_dims << " and " << b_dims
                             << ". Please make sure that the shape of `a` and `b` be like [a, b, c, ..., N, N] X [a, "
                                "b, c, ..., N, M] or [a, b, c, ..., N, N] X [a, b, c, ..., N].";
  }
  return b_base_shape->cast<abstract::ShapePtr>();
}
// Infers the output dtype of SolveTriangular: `a` and `b` must be tensors of the same type,
// restricted to float32 or float64; the check helper's result is returned as the output type.
TypePtr SolveTriangularInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  std::map<std::string, TypePtr> named_types;
  (void)named_types.emplace("a type", input_args[kInputIndex0]->BuildType());
  (void)named_types.emplace("b type", input_args[kInputIndex1]->BuildType());
  return CheckAndConvertUtils::CheckTensorTypeSame(named_types, {kFloat32, kFloat64}, op_name);
}
} // namespace
void SolveTriangular::Init(bool lower, bool unit_diagonal, std::string trans) { set_unit_diagonal(unit_diagonal); }
// Stores `unit_diagonal` as the kUnitDiagonal attribute of this operator.
void SolveTriangular::set_unit_diagonal(bool unit_diagonal) {
  auto attr_value = api::MakeValue(unit_diagonal);
  (void)AddAttr(kUnitDiagonal, attr_value);
}
// Reads the kUnitDiagonal attribute back as a bool.
bool SolveTriangular::get_unit_diagonal() const { return GetValue<bool>(GetAttr(kUnitDiagonal)); }
// Stores `lower` as the kLower attribute of this operator.
void SolveTriangular::set_lower(bool lower) {
  auto attr_value = api::MakeValue(lower);
  (void)AddAttr(kLower, attr_value);
}
// Reads the kLower attribute back as a bool.
bool SolveTriangular::get_lower() const { return GetValue<bool>(GetAttr(kLower)); }
// Stores `trans` as the kTrans attribute of this operator.
void SolveTriangular::set_trans(std::string trans) {
  auto attr_value = api::MakeValue(trans);
  (void)AddAttr(kTrans, attr_value);
}
// Reads the kTrans attribute back as a string.
std::string SolveTriangular::get_trans() const { return GetValue<std::string>(GetAttr(kTrans)); }
MIND_API_OPERATOR_IMPL(SolveTriangular, BaseOperator);
// Infer entry point for SolveTriangular: checks that exactly two inputs are given, then infers
// the dtype followed by the shape and combines them into the output abstract.
AbstractBasePtr SolveTriangularInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t kExpectedInputNum = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kExpectedInputNum, primitive->name());
  // Type is inferred before shape, so dtype errors are reported first.
  const auto out_type = SolveTriangularInferType(primitive, input_args);
  const auto out_shape = SolveTriangularInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SolveTriangular, prim::kPrimSolveTriangular, SolveTriangularInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
#ifndef MINDSPORE_CORE_OPS_SOLVE_TRIANGULAR_H_
#define MINDSPORE_CORE_OPS_SOLVE_TRIANGULAR_H_
#include <map>
#include <set>
#include <string>
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSolveTriangular = "SolveTriangular";
/// \brief SolveTriangular operator prototype. It has two inputs, `a` and `b`, one output, `output`,
/// and carries three attributes: `lower`, `unit_diagonal` and `trans`.
class MIND_API SolveTriangular : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SolveTriangular);
  /// \brief Constructor. Registers the input names {"a", "b"} and the output name {"output"}.
  SolveTriangular() : BaseOperator(kNameSolveTriangular) { InitIOName({"a", "b"}, {"output"}); }

  /// \brief Init the operator from its attribute values.
  ///
  /// \param[in] lower Value for the `lower` attribute.
  /// \param[in] unit_diagonal Value for the `unit_diagonal` attribute.
  /// \param[in] trans Value for the `trans` attribute.
  void Init(bool lower, bool unit_diagonal, std::string trans);

  /// \brief Set the `lower` attribute.
  void set_lower(bool lower);
  /// \brief Set the `unit_diagonal` attribute.
  void set_unit_diagonal(bool unit_diagonal);
  /// \brief Set the `trans` attribute.
  void set_trans(std::string trans);

  /// \brief Get the `lower` attribute.
  bool get_lower() const;
  /// \brief Get the `unit_diagonal` attribute.
  bool get_unit_diagonal() const;
  /// \brief Get the `trans` attribute.
  std::string get_trans() const;
};
/// \brief Infer function for SolveTriangular; builds the output abstract from the input abstracts.
abstract::AbstractBasePtr SolveTriangularInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
/// \brief Shared-pointer alias for the SolveTriangular operator.
using PrimSolveTriangularPtr = std::shared_ptr<SolveTriangular>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SOLVE_TRIANGULAR_H_

View File

@ -18,7 +18,7 @@ from .._checkparam import Validator as validator
from ..common import dtype as mstype
class SolveTriangular(PrimitiveWithInfer):
class SolveTriangular(Primitive):
"""
Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.
@ -87,49 +87,6 @@ class SolveTriangular(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['a', 'b'], outputs=['output'])
# Validates the shapes and dtypes of `a` and `b` and returns the inferred output description.
# `a` and `b` are dicts read here through their 'shape' and 'dtype' keys.
# Raises ValueError when any shape rule below is violated.
def __infer__(self, a, b):
a_shape = a['shape']
b_shape = b['shape']
# shape match
# `b` is treated as a vector when it has exactly one dimension fewer than `a`.
b_vector = len(b_shape) == len(a_shape) - 1
if len(a_shape) < 2:
raise ValueError(f"For '{self.name}', the dimension of `a` should be at least 2,"
f" but got {len(a_shape)} dimensions.")
# Number of trailing (non-batch) dimensions of `b`: 1 for a vector, 2 for a matrix.
b_len = 1 if b_vector else 2
if len(b_shape) < b_len:
raise ValueError(f"For '{self.name}', the dimension of `b` should be at least {b_len},"
f" but got {len(b_shape)} dimensions.")
# `b` must have the same rank as `a`, or exactly one less.
if len(a_shape) != len(b_shape) and len(a_shape) - 1 != len(b_shape):
raise ValueError(f"For '{self.name}', the dimension of `b` should be 'a.dim' or 'a.dim' - 1, "
f"which is {len(a_shape)} or {len(a_shape) - 1}, but got {len(b_shape)} dimensions.")
# The last two dimensions of `a` must be equal (square matrices).
if a_shape[-1] != a_shape[-2]:
raise ValueError(f"For '{self.name}', the last two dimensions of `a` should be the same,"
f" but got shape of {a_shape}."
f" Please make sure that the shape of `a` be like [..., N, N]")
# The row count of `a` must equal the first trailing dimension of `b`.
if a_shape[-2] != b_shape[-b_len]:
raise ValueError(f"For '{self.name}', the last two dimensions of `a` and `b` should be matched,"
f" but got shape of {a_shape} and {b_shape}."
f" Please make sure that the shape of `a` and `b` be like"
f" [..., N, N] X [..., N, M] or [..., N, N] X [..., N].")
# All leading (batch) dimensions must be identical.
if a_shape[:-2] != b_shape[:-b_len]:
raise ValueError(f"For '{self.name}', the batch dimensions of `a` and `b` should all be the same,"
f" but got shape of {a_shape} and {b_shape}."
f" Please make sure that the shape of `a` and `b` be like"
f" [a, b, c, ..., N, N] X [a, b, c, ..., N, M] or"
f" [a, b, c, ..., N, N] X [a, b, c, ..., N].")
# Both dtypes must match and be float32 or float64.
validator.check_scalar_or_tensor_types_same({"a_dtype": a['dtype'], "b_dtype": b['dtype']},
[mstype.float32, mstype.float64], self.name)
# Output has the shape of `b` and the dtype of `a`; no constant value is inferred.
return {
'shape': tuple(b_shape),
'dtype': a['dtype'],
'value': None
}
# Returns the dtype of `a` as the output dtype; `b_dtype` is intentionally unused and discarded.
def infer_dtype(self, a_dtype, b_dtype):
del b_dtype
return a_dtype
class Eigh(PrimitiveWithInfer):
"""

View File

@ -20,7 +20,8 @@ import numpy as np
import scipy as scp
from scipy.linalg import solve_triangular, eig, eigvals
from mindspore import Tensor, context
from mindspore import Tensor, context, nn
from mindspore.common import dtype as mstype
from mindspore.ops.operations.math_ops import Cholesky
from mindspore.scipy.ops import Eigh, Eig, SolveTriangular
from mindspore.scipy.utils import _nd_transpose
@ -29,6 +30,15 @@ from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_ma
np.random.seed(0)
class SolveTriangularNet(nn.Cell):
    """Thin Cell wrapper around the SolveTriangular primitive, used by the tests."""

    def __init__(self, lower: bool = False, unit_diagonal: bool = False, trans: str = 'N'):
        """Build the wrapped primitive from `lower`, `unit_diagonal` and `trans`."""
        super().__init__()
        self.solve = SolveTriangular(lower, unit_diagonal, trans)

    def construct(self, a, b):
        """Apply the wrapped primitive to `a` and `b` and return its result."""
        solution = self.solve(a, b)
        return solution
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@ -317,7 +327,14 @@ def test_solve_triangular_matrix(shape: int, dtype, lower: bool, unit_diagonal:
a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
b = np.random.random((m, n)).astype(dtype)
expect = solve_triangular(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
output = SolveTriangular(lower, unit_diagonal, trans)(Tensor(a), Tensor(b)).asnumpy()
type_convert = {np.float32: mstype.float32, np.float64: mstype.float64}
dynamic_net = SolveTriangularNet(lower, unit_diagonal, trans)
place_holder = Tensor(shape=[None, None], dtype=type_convert.get(dtype))
dynamic_net.set_inputs(place_holder, place_holder)
output = dynamic_net(Tensor(a), Tensor(b)).asnumpy()
np.testing.assert_almost_equal(expect, output, decimal=5)