add dynamic support for solve triangular
This commit is contained in:
parent
e24f136909
commit
0bdfd3e052
|
@ -1,166 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h"
|
||||
#include <Eigen/Dense>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::Dynamic;
|
||||
using Eigen::Lower;
|
||||
using Eigen::Map;
|
||||
using Eigen::MatrixBase;
|
||||
using Eigen::RowMajor;
|
||||
using Eigen::UnitLower;
|
||||
using Eigen::UnitUpper;
|
||||
using Eigen::Upper;
|
||||
template <typename T, int Major>
|
||||
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
|
||||
constexpr auto kSolveTriangularInputsNum = 2;
|
||||
constexpr auto kSolveTriangularOutputsNum = 1;
|
||||
constexpr auto kAVectorxDimNum = 1;
|
||||
constexpr auto kAMatrixDimNum = 2;
|
||||
constexpr size_t kRowIndex = 2;
|
||||
constexpr size_t kColIndex = 1;
|
||||
void MatrixTriangularSolveCpuKernelMod::InitShape(const CNodePtr &kernel_node) {
|
||||
auto a_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0));
|
||||
auto b_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1));
|
||||
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
|
||||
size_t a_dims = a_shape.size();
|
||||
size_t aRowIndex = a_dims - kRowIndex;
|
||||
m_ = a_shape[aRowIndex];
|
||||
size_t b_sims = b_shape.size();
|
||||
bool vector_b = b_sims == a_dims - 1;
|
||||
if (vector_b) {
|
||||
n_ = 1;
|
||||
} else {
|
||||
n_ = b_shape[b_sims - 1];
|
||||
}
|
||||
batch_ = 1;
|
||||
for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) {
|
||||
batch_ *= a_shape[batch];
|
||||
}
|
||||
}
|
||||
|
||||
void MatrixTriangularSolveCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
InitShape(kernel_node);
|
||||
if (common::AnfAlgo::HasNodeAttr(ADJOINT, kernel_node)) {
|
||||
// MatrixTriangularSolve attribute
|
||||
trans_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJOINT);
|
||||
if (common::AnfAlgo::HasNodeAttr(TRANS, kernel_node)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the attribute 'adjoint' and 'trans' could not exist at the same time.";
|
||||
}
|
||||
} else {
|
||||
lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
|
||||
unit_diagonal_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, UNIT_DIAGONAL);
|
||||
const std::string trans = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, TRANS);
|
||||
if (trans == "N") {
|
||||
trans_ = false;
|
||||
} else if (trans == "T") {
|
||||
trans_ = true;
|
||||
} else if (trans == "C") {
|
||||
trans_ = true;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' must be in ['N', 'T', 'C'], but got [" << trans
|
||||
<< "].";
|
||||
}
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "SolveTriangular does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename Derived_a, typename Derived_b, typename T>
|
||||
inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b, T *output_addr, int m, int n,
|
||||
bool lower, bool unit_diagonal) {
|
||||
Map<Matrix<T, RowMajor>> output(output_addr, m, n);
|
||||
if (unit_diagonal) {
|
||||
if (lower) {
|
||||
output.noalias() = a.template triangularView<UnitLower>().solve(b);
|
||||
} else {
|
||||
output.noalias() = a.template triangularView<UnitUpper>().solve(b);
|
||||
}
|
||||
} else {
|
||||
if (lower) {
|
||||
output.noalias() = a.template triangularView<Lower>().solve(b);
|
||||
} else {
|
||||
output.noalias() = a.template triangularView<Upper>().solve(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MatrixTriangularSolveCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_);
|
||||
|
||||
auto a_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto b_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
size_t a_batch_size = m_ * m_;
|
||||
size_t b_batch_size = m_ * n_;
|
||||
size_t output_batch_size = m_ * n_;
|
||||
|
||||
for (size_t i = 0; i < batch_; ++i) {
|
||||
T *a_batch_addr = a_addr + i * a_batch_size;
|
||||
T *b_batch_addr = b_addr + i * b_batch_size;
|
||||
T *output_batch_addr = output_addr + i * output_batch_size;
|
||||
|
||||
Map<Matrix<T, RowMajor>> b(b_batch_addr, m_, n_);
|
||||
if (trans_) {
|
||||
Map<Matrix<T, ColMajor>> a(a_batch_addr, m_, m_);
|
||||
solve(a, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_);
|
||||
} else {
|
||||
Map<Matrix<T, RowMajor>> a(a_batch_addr, m_, m_);
|
||||
solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Table pairing each supported (input a, input b, output) dtype signature with the LaunchKernel
// instantiation that handles it. Only float32 and float64 are listed. InitKernel indexes into
// this table after matching the node's kernel attributes, and GetOpSupport exposes its keys.
std::vector<std::pair<KernelAttr, MatrixTriangularSolveCpuKernelMod::MatrixTriangularSolveFunc>>
|
||||
MatrixTriangularSolveCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> MatrixTriangularSolveCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, MatrixTriangularSolveFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
// The same kernel class is registered under two operator names, SolveTriangular and
// MatrixTriangularSolve; InitKernel distinguishes them by which attributes the node carries.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SolveTriangular, MatrixTriangularSolveCpuKernelMod);
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixTriangularSolve, MatrixTriangularSolveCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,151 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h"
|
||||
#include <Eigen/Dense>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "mindspore/core/ops/solve_triangular.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::Dynamic;
|
||||
using Eigen::Lower;
|
||||
using Eigen::Map;
|
||||
using Eigen::MatrixBase;
|
||||
using Eigen::RowMajor;
|
||||
using Eigen::UnitLower;
|
||||
using Eigen::UnitUpper;
|
||||
using Eigen::Upper;
|
||||
template <typename T, int Major>
|
||||
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
|
||||
using KernelRunFunc = SolveTriangularCpuKernelMod::KernelRunFunc;
|
||||
constexpr auto kSolveTriangularInputsNum = 2;
|
||||
constexpr auto kSolveTriangularOutputsNum = 1;
|
||||
int SolveTriangularCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto a_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
|
||||
auto b_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
|
||||
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
|
||||
size_t a_dims = a_shape.size();
|
||||
size_t b_dims = b_shape.size();
|
||||
m_ = a_shape[a_dims - kIndex2];
|
||||
n_ = (b_dims == a_dims - 1) ? 1 : b_shape[b_dims - 1];
|
||||
batch_ = std::accumulate(a_shape.begin(), a_shape.end() - kIndex2, int64_t(1), std::multiplies{});
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool SolveTriangularCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
|
||||
auto kernel_ptr = std::make_shared<ops::SolveTriangular>(base_operator->GetPrim());
|
||||
lower_ = kernel_ptr->get_lower();
|
||||
unit_diagonal_ = kernel_ptr->get_unit_diagonal();
|
||||
const std::string trans = kernel_ptr->get_trans();
|
||||
if (trans == "N") {
|
||||
trans_ = false;
|
||||
} else if (trans == "T") {
|
||||
trans_ = true;
|
||||
} else if (trans == "C") {
|
||||
// currently does not support complex.
|
||||
trans_ = true;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' must be in ['N', 'T', 'C'], but got [" << trans << "].";
|
||||
}
|
||||
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Derived_a, typename Derived_b, typename T>
|
||||
inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b, T *output_addr, int m, int n,
|
||||
bool lower, bool unit_diagonal) {
|
||||
Map<Matrix<T, RowMajor>> output(output_addr, m, n);
|
||||
if (unit_diagonal) {
|
||||
if (lower) {
|
||||
output.noalias() = a.template triangularView<UnitLower>().solve(b);
|
||||
} else {
|
||||
output.noalias() = a.template triangularView<UnitUpper>().solve(b);
|
||||
}
|
||||
} else {
|
||||
if (lower) {
|
||||
output.noalias() = a.template triangularView<Lower>().solve(b);
|
||||
} else {
|
||||
output.noalias() = a.template triangularView<Upper>().solve(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool SolveTriangularCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_);
|
||||
|
||||
auto a_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto b_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
size_t a_batch_size = m_ * m_;
|
||||
size_t b_batch_size = m_ * n_;
|
||||
size_t output_batch_size = m_ * n_;
|
||||
|
||||
for (size_t i = 0; i < batch_; ++i) {
|
||||
T *a_batch_addr = a_addr + i * a_batch_size;
|
||||
T *b_batch_addr = b_addr + i * b_batch_size;
|
||||
T *output_batch_addr = output_addr + i * output_batch_size;
|
||||
|
||||
Map<Matrix<T, RowMajor>> b(b_batch_addr, m_, n_);
|
||||
if (trans_) {
|
||||
Map<Matrix<T, ColMajor>> a(a_batch_addr, m_, m_);
|
||||
solve(a, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_);
|
||||
} else {
|
||||
Map<Matrix<T, RowMajor>> a(a_batch_addr, m_, m_);
|
||||
solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the table pairing each supported (input a, input b, output) dtype signature with the
// LaunchKernel instantiation that handles it. Only float32 and float64 are listed. The table is
// a function-local static, so the same vector is returned by reference on every call.
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SolveTriangularCpuKernelMod::GetFuncList() const {
|
||||
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&SolveTriangularCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&SolveTriangularCpuKernelMod::LaunchKernel<double>},
|
||||
};
|
||||
return func_list;
|
||||
}
|
||||
// Registers this kernel class under the operator name SolveTriangular only.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SolveTriangular, SolveTriangularCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -19,35 +19,36 @@
|
|||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MatrixTriangularSolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class SolveTriangularCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<SolveTriangularCpuKernelMod> {
|
||||
public:
|
||||
MatrixTriangularSolveCpuKernelMod() = default;
|
||||
~MatrixTriangularSolveCpuKernelMod() override = default;
|
||||
SolveTriangularCpuKernelMod() = default;
|
||||
~SolveTriangularCpuKernelMod() override = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
void InitShape(const CNodePtr &kernel_node);
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using MatrixTriangularSolveFunc =
|
||||
std::function<bool(MatrixTriangularSolveCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, MatrixTriangularSolveFunc>> func_list_;
|
||||
MatrixTriangularSolveFunc kernel_func_;
|
||||
|
||||
size_t m_{0};
|
||||
size_t n_{0};
|
|
@ -1,30 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// GPU registrations for the SolveTriangular operator: one instantiation of
// MatrixTriangularSolveGpuKernelMod per supported dtype (float32 and float64), where both
// inputs and the output share that dtype.
MS_REG_GPU_KERNEL_ONE(
|
||||
SolveTriangular,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MatrixTriangularSolveGpuKernelMod, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SolveTriangular,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
MatrixTriangularSolveGpuKernelMod, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,259 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/kernel_constants.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr auto kAVectorxDimNum = 1;
|
||||
constexpr auto kAMatrixDimNum = 2;
|
||||
constexpr size_t kRowIndex = 2;
|
||||
constexpr size_t kColIndex = 1;
|
||||
constexpr size_t kShape3D = 3;
|
||||
constexpr size_t kIndexAArray = 0;
|
||||
constexpr size_t kIndexDstArray = 1;
|
||||
constexpr size_t kIndexBBuffer = 2;
|
||||
constexpr size_t kIndexBTransposeShape = 3;
|
||||
constexpr size_t kIndexBTransposeAxis = 4;
|
||||
template <typename T>
|
||||
class MatrixTriangularSolveGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
public:
|
||||
MatrixTriangularSolveGpuKernelMod() = default;
|
||||
~MatrixTriangularSolveGpuKernelMod() = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cublasSetStream failed");
|
||||
auto inputa_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
// if b is not a vector, solve b in the workspace
|
||||
T *dst = nullptr;
|
||||
if (n_ == 1) {
|
||||
dst = output_addr;
|
||||
} else {
|
||||
dst = GetDeviceAddress<T>(workspace, kIndexBBuffer);
|
||||
}
|
||||
|
||||
const size_t batched_b_size = batch_ * m_ * n_;
|
||||
if (n_ == 1) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(dst, inputb_addr, batched_b_size * sizeof(T), cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync dst failed");
|
||||
} else {
|
||||
// No matter how many batch dimensions the batched matrix b has, use their cumulative multiplication batch.
|
||||
// In order to convert row major matrix b(batch, m, n) to col major matrix b'(batch, m, n),
|
||||
// the following operation is equivalent to:
|
||||
// b' = b.tarnspose(batch, n, m).reshape(batch, m, n)
|
||||
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
|
||||
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
|
||||
MatrixTransposeND(inputb_addr, {batch_, m_, n_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape,
|
||||
dev_transpose_b_axis, dst, reinterpret_cast<cudaStream_t>(stream_ptr), kernel_name_);
|
||||
}
|
||||
|
||||
// index calculation
|
||||
auto device_a_array_addr = GetDeviceAddress<T *>(workspace, kIndexAArray);
|
||||
auto device_dst_array_addr = GetDeviceAddress<T *>(workspace, kIndexDstArray);
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
host_a_array_[i] = inputa_addr + i * m_ * m_;
|
||||
host_dst_array_[i] = dst + i * m_ * n_;
|
||||
}
|
||||
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(device_a_array_addr, host_a_array_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(device_dst_array_addr, host_dst_array_.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
|
||||
T alpha = 1;
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUBLAS_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
|
||||
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
|
||||
"cublas trsm Fail");
|
||||
} else {
|
||||
CHECK_CUBLAS_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cublasDtrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
|
||||
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
|
||||
"cublas trsm Fail");
|
||||
}
|
||||
|
||||
// if x is not a vector, do transpose
|
||||
if (n_ != 1) {
|
||||
// in order to convert col major matrix x'(m x n) to row major matrix x'(m x n),
|
||||
// the following operation is equivalent to:
|
||||
// x = x'.reshape(n, m).T
|
||||
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
|
||||
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
|
||||
MatrixTransposeND(dst, {batch_, n_, m_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape, dev_transpose_b_axis,
|
||||
output_addr, reinterpret_cast<cudaStream_t>(stream_ptr), kernel_name_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
|
||||
|
||||
InitShape(kernel_node);
|
||||
if (is_dynamic_) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
lda_ = SizeToInt(m_);
|
||||
ldb_ = SizeToInt(m_);
|
||||
|
||||
if (common::AnfAlgo::HasNodeAttr("adjoint", kernel_node)) {
|
||||
// MatrixTriangularSolve attribute
|
||||
bool trans = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "adjoint");
|
||||
// converting row major to col major is the same as reverting the trans flag
|
||||
trans_ = trans ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
if (common::AnfAlgo::HasNodeAttr("trans", kernel_node)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the attribute 'adjoint' and 'trans' could not exist at the same time.";
|
||||
}
|
||||
} else {
|
||||
bool lower = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");
|
||||
// reverting the trans flag by default, so also flip the lower flag
|
||||
lower = !lower;
|
||||
uplo_ = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
bool unit_diagonal = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "unit_diagonal");
|
||||
unit_diagonal_ = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
||||
const std::string trans = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans");
|
||||
// converting row major to col major is the same as reverting the trans flag
|
||||
if (trans == "N") {
|
||||
trans_ = CUBLAS_OP_T;
|
||||
} else if (trans == "T") {
|
||||
trans_ = CUBLAS_OP_N;
|
||||
} else if (trans == "C") {
|
||||
trans_ = CUBLAS_OP_N;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "].";
|
||||
}
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
size_t unit_size = sizeof(T);
|
||||
size_t a_size = batch_ * m_ * m_ * unit_size;
|
||||
size_t b_size = batch_ * m_ * n_ * unit_size;
|
||||
input_size_list_ = {a_size, b_size};
|
||||
output_size_list_ = {b_size};
|
||||
if (n_ != 1) {
|
||||
workspace_size_list_ = {
|
||||
// workspace for batched a
|
||||
batch_ * sizeof(T *),
|
||||
// workspace for batched b
|
||||
batch_ * sizeof(T *),
|
||||
// workspace for transposed b
|
||||
b_size,
|
||||
// workspace for b transpose shape
|
||||
kShape3D * sizeof(size_t *),
|
||||
// workspace for b transpose axis
|
||||
kShape3D * sizeof(size_t *),
|
||||
};
|
||||
} else {
|
||||
workspace_size_list_ = {
|
||||
// workspace for batched a
|
||||
batch_ * sizeof(T *),
|
||||
// workspace for batched b
|
||||
batch_ * sizeof(T *),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
void InitShape(const CNodePtr &kernel_node) {
|
||||
is_dynamic_ = false;
|
||||
auto a_shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto b_shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
if (AnfAlgo::IsShapesDynamic({a_shape_signed, b_shape_signed})) {
|
||||
is_dynamic_ = true;
|
||||
return;
|
||||
}
|
||||
auto a_shape = Convert2SizeTClipNeg(a_shape_signed);
|
||||
auto b_shape = Convert2SizeTClipNeg(b_shape_signed);
|
||||
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(a_shape, kernel_name_, "input_a") || CHECK_SHAPE_NULL(b_shape, kernel_name_, "input_b");
|
||||
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
|
||||
size_t a_dims = a_shape.size();
|
||||
size_t aRowIndex = a_dims - kRowIndex;
|
||||
m_ = a_shape[aRowIndex];
|
||||
size_t b_sims = b_shape.size();
|
||||
bool vector_b = b_sims == a_dims - 1;
|
||||
if (vector_b) {
|
||||
n_ = 1;
|
||||
} else {
|
||||
n_ = b_shape[b_sims - 1];
|
||||
}
|
||||
batch_ = 1;
|
||||
for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) {
|
||||
batch_ *= a_shape[batch];
|
||||
}
|
||||
host_a_array_.resize(batch_);
|
||||
host_dst_array_.resize(batch_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t m_{0};
|
||||
size_t n_{0};
|
||||
size_t batch_{1};
|
||||
int lda_{0};
|
||||
int ldb_{0};
|
||||
bool is_null_input_{false};
|
||||
bool is_dynamic_{false};
|
||||
std::vector<T *> host_a_array_;
|
||||
std::vector<T *> host_dst_array_;
|
||||
cublasHandle_t blas_handle_{nullptr};
|
||||
cublasFillMode_t uplo_{CUBLAS_FILL_MODE_UPPER};
|
||||
cublasOperation_t trans_{CUBLAS_OP_N};
|
||||
cublasDiagType_t unit_diagonal_{CUBLAS_DIAG_NON_UNIT};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
|
|
@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h"
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "mindspore/core/ops/solve_triangular.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using KernelRunFunc = SolveTriangularGpuKernelMod::KernelRunFunc;
|
||||
bool SolveTriangularGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
|
||||
|
||||
auto kernel_ptr = std::make_shared<ops::SolveTriangular>(base_operator->GetPrim());
|
||||
|
||||
bool lower = kernel_ptr->get_lower();
|
||||
// reverting the trans flag by default, so also flip the lower flag
|
||||
lower = !lower;
|
||||
uplo_ = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
|
||||
bool unit_diagonal = kernel_ptr->get_unit_diagonal();
|
||||
unit_diagonal_ = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
||||
|
||||
const std::string trans = kernel_ptr->get_trans();
|
||||
if (trans == "N") {
|
||||
trans_ = CUBLAS_OP_T;
|
||||
} else if (trans == "T") {
|
||||
trans_ = CUBLAS_OP_N;
|
||||
} else if (trans == "C") {
|
||||
// currently does not support complex.
|
||||
trans_ = CUBLAS_OP_N;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' must be in ['N', 'T', 'C'], but got [" << trans << "].";
|
||||
}
|
||||
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
int SolveTriangularGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto a_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
|
||||
auto b_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
|
||||
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(a_shape, kernel_name_, "input_a") || CHECK_SHAPE_NULL(b_shape, kernel_name_, "input_b");
|
||||
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
|
||||
size_t a_dims = a_shape.size();
|
||||
size_t b_dims = b_shape.size();
|
||||
m_ = a_shape[a_dims - kIndex2];
|
||||
n_ = (b_dims == a_dims - 1) ? 1 : b_shape[b_dims - 1];
|
||||
batch_ = std::accumulate(a_shape.begin(), a_shape.end() - kIndex2, int64_t(1), std::multiplies{});
|
||||
|
||||
lda_ = SizeToInt(m_);
|
||||
ldb_ = SizeToInt(m_);
|
||||
|
||||
const size_t unit_size = GetTypeByte(TypeIdToType(inputs.at(kIndex0)->GetDtype()));
|
||||
constexpr size_t pointer_size = sizeof(float *);
|
||||
size_t b_size = batch_ * m_ * n_ * unit_size;
|
||||
workspace_size_list_.clear();
|
||||
if (n_ != 1) {
|
||||
workspace_size_list_ = {
|
||||
// workspace for batched a
|
||||
batch_ * pointer_size,
|
||||
// workspace for batched b
|
||||
batch_ * pointer_size,
|
||||
// workspace for transposed b
|
||||
b_size,
|
||||
// workspace for b transpose shape
|
||||
kShape3D * sizeof(size_t *),
|
||||
// workspace for b transpose axis
|
||||
kShape3D * sizeof(size_t *),
|
||||
};
|
||||
} else {
|
||||
workspace_size_list_ = {
|
||||
// workspace for batched a
|
||||
batch_ * pointer_size,
|
||||
// workspace for batched b
|
||||
batch_ * pointer_size,
|
||||
};
|
||||
}
|
||||
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool SolveTriangularGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, cuda_stream_), "cublasSetStream failed");
|
||||
auto inputa_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
auto output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
std::vector<T *> host_a_array(batch_);
|
||||
std::vector<T *> host_dst_array(batch_);
|
||||
|
||||
// if b is not a vector, solve b in the workspace
|
||||
T *dst = nullptr;
|
||||
if (n_ == 1) {
|
||||
dst = output_addr;
|
||||
} else {
|
||||
dst = GetDeviceAddress<T>(workspace, kIndexBBuffer);
|
||||
}
|
||||
|
||||
const size_t batched_b_size = batch_ * m_ * n_;
|
||||
if (n_ == 1) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(dst, inputb_addr, batched_b_size * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream_),
|
||||
"cudaMemcpyAsync dst failed");
|
||||
} else {
|
||||
// No matter how many batch dimensions the batched matrix b has, use their cumulative multiplication batch.
|
||||
// In order to convert row major matrix b(batch, m, n) to col major matrix b'(batch, m, n),
|
||||
// the following operation is equivalent to:
|
||||
// b' = b.tarnspose(batch, n, m).reshape(batch, m, n)
|
||||
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
|
||||
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
|
||||
MatrixTransposeND(inputb_addr, {batch_, m_, n_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape, dev_transpose_b_axis,
|
||||
dst, cuda_stream_, kernel_name_);
|
||||
}
|
||||
|
||||
// index calculation
|
||||
auto device_a_array_addr = GetDeviceAddress<T *>(workspace, kIndexAArray);
|
||||
auto device_dst_array_addr = GetDeviceAddress<T *>(workspace, kIndexDstArray);
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
host_a_array[i] = inputa_addr + i * m_ * m_;
|
||||
host_dst_array[i] = dst + i * m_ * n_;
|
||||
}
|
||||
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(device_a_array_addr, host_a_array.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, cuda_stream_),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(device_dst_array_addr, host_dst_array.data(), sizeof(T *) * batch_,
|
||||
cudaMemcpyHostToDevice, cuda_stream_),
|
||||
"cuda memcopy Fail");
|
||||
|
||||
T alpha = 1;
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
|
||||
cublasStrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
|
||||
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
|
||||
"cublas trsm Fail");
|
||||
} else {
|
||||
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
|
||||
cublasDtrsmBatched(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_, &alpha,
|
||||
device_a_array_addr, lda_, device_dst_array_addr, ldb_, batch_),
|
||||
"cublas trsm Fail");
|
||||
}
|
||||
|
||||
// if x is not a vector, do transpose
|
||||
if (n_ != 1) {
|
||||
// in order to convert col major matrix x'(m x n) to row major matrix x'(m x n),
|
||||
// the following operation is equivalent to:
|
||||
// x = x'.reshape(n, m).T
|
||||
auto dev_transpose_b_shape = GetDeviceAddress<size_t>(workspace, kIndexBTransposeShape);
|
||||
auto dev_transpose_b_axis = GetDeviceAddress<size_t>(workspace, kIndexBTransposeAxis);
|
||||
MatrixTransposeND(dst, {batch_, n_, m_}, {kDim0, kDim2, kDim1}, dev_transpose_b_shape, dev_transpose_b_axis,
|
||||
output_addr, cuda_stream_, kernel_name_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SolveTriangularGpuKernelMod::GetFuncList() const {
|
||||
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&SolveTriangularGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&SolveTriangularGpuKernelMod::LaunchKernel<double>},
|
||||
};
|
||||
return func_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SolveTriangular, SolveTriangularGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/kernel_constants.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Dimension-count and index helpers.
// NOTE(review): none of these four is referenced by the GPU kernel code visible in this change - confirm
// they are still needed here.
constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;

// Number of entries in the shape/axis arrays used when transposing b viewed as (batch, m, n).
constexpr size_t kShape3D = 3;

// Positions of the buffers inside the kernel workspace list.
// Per-batch device pointers into a.
constexpr size_t kIndexAArray = 0;
// Per-batch device pointers into the buffer that is solved in place.
constexpr size_t kIndexDstArray = 1;
// Buffer holding the transposed b (only present when b is a matrix).
constexpr size_t kIndexBBuffer = 2;
// Device copy of the transpose shape (only present when b is a matrix).
constexpr size_t kIndexBTransposeShape = 3;
// Device copy of the transpose axis order (only present when b is a matrix).
constexpr size_t kIndexBTransposeAxis = 4;
|
||||
|
||||
/// GPU kernel module for the SolveTriangular operator, built on the batched cuBLAS trsm routines.
class SolveTriangularGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<SolveTriangularGpuKernelMod> {
 public:
  SolveTriangularGpuKernelMod() = default;
  ~SolveTriangularGpuKernelMod() override = default;

  /// One-time setup from the operator; returns false when setup fails.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  /// Refreshes the shape-dependent members and workspace sizes for the current input shapes.
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  /// Runs the dtype-specific kernel on the given stream; succeeds immediately when an input is empty.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (!is_null_input_) {
      cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
      return kernel_func_(this, inputs, workspace, outputs);
    }
    return true;
  }

  /// Table of supported dtype signatures and their launch functions.
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

 private:
  /// Typed implementation selected through the table returned by GetFuncList.
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                    const std::vector<AddressPtr> &outputs);

  // Problem sizes: a is (batch_, m_, m_), b is (batch_, m_, n_); n_ is 1 when b is a batch of vectors.
  size_t m_{0};
  size_t n_{0};
  size_t batch_{1};
  // Leading dimensions handed to cuBLAS for a and b.
  int lda_{0};
  int ldb_{0};
  // Set by Resize when either input shape is empty; Launch then does nothing.
  bool is_null_input_{false};

  // cuBLAS handle and trsm options.
  cublasHandle_t blas_handle_{nullptr};
  cublasFillMode_t uplo_{CUBLAS_FILL_MODE_UPPER};
  cublasOperation_t trans_{CUBLAS_OP_N};
  cublasDiagType_t unit_diagonal_{CUBLAS_DIAG_NON_UNIT};

  // Stream of the current launch.
  cudaStream_t cuda_stream_{nullptr};
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
|
|
@ -124,6 +124,7 @@ constexpr auto kKSize = "ksize";
|
|||
constexpr auto kKsizes = "ksizes";
|
||||
constexpr auto kKernelType = "kernel_type";
|
||||
constexpr auto kLimit = "limit";
|
||||
constexpr auto kLower = "lower";
|
||||
constexpr auto kMagSquare = "mag_square";
|
||||
constexpr auto kMax = "max";
|
||||
constexpr auto kMaxSizes = "max_sizes";
|
||||
|
@ -228,10 +229,12 @@ constexpr auto kSummarize = "summarize";
|
|||
constexpr auto kTimeMajor = "time_major";
|
||||
constexpr auto kTolerance = "tolerance";
|
||||
constexpr auto kTopK = "top_k";
|
||||
constexpr auto kTrans = "trans";
|
||||
constexpr auto kTransposeA = "transpose_a";
|
||||
constexpr auto kTransposeB = "transpose_b";
|
||||
constexpr auto kNegativeSlope = "negative_slope";
|
||||
constexpr auto kType = "type";
|
||||
constexpr auto kUnitDiagonal = "unit_diagonal";
|
||||
constexpr auto kUpdateSlots = "update_slots";
|
||||
constexpr auto kUseAxis = "use_axis";
|
||||
constexpr auto kUseLocking = "use_locking";
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/solve_triangular.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
/**
 * Infer the output shape of SolveTriangular, which is the shape of `b`.
 *
 * For fully known shapes this checks that `a` looks like [..., N, N] and that `b` looks like
 * [..., N, M] or [..., N] with the same batch dimensions. Shapes that are dynamic skip the checks.
 *
 * @param primitive  The SolveTriangular primitive; must not be null.
 * @param input_args Abstract inputs; index 0 is `a`, index 1 is `b`.
 * @return The inferred output shape.
 * @throws ValueError (via MS_EXCEPTION) when the static shapes of `a` and `b` are incompatible.
 */
abstract::ShapePtr SolveTriangularInferShape(const PrimitivePtr &primitive,
                                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto prim_name = primitive->name();

  auto a_shape_ptr = input_args[kInputIndex0]->BuildShape();
  MS_EXCEPTION_IF_NULL(a_shape_ptr);
  auto b_shape_ptr = input_args[kInputIndex1]->BuildShape();
  MS_EXCEPTION_IF_NULL(b_shape_ptr);

  // Unknown rank on either input.
  // NOTE(review): this returns a one-element shape holding kShapeDimAny - confirm this is the intended
  // encoding for an output of unknown rank.
  if (a_shape_ptr->IsDimUnknown() || b_shape_ptr->IsDimUnknown()) {
    return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeDimAny});
  }

  // Some dimension is still dynamic: the output follows b and validation is skipped.
  if (a_shape_ptr->IsDynamic() || b_shape_ptr->IsDynamic()) {
    return b_shape_ptr->cast<abstract::ShapePtr>();
  }

  const auto a_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(a_shape_ptr)[kShape];
  const auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(b_shape_ptr)[kShape];
  const size_t a_dim = a_shape.size();
  const size_t b_dim = b_shape.size();

  // Number of trailing non-batch dimensions: 2 for the matrix a, and 1 or 2 for b depending on whether b
  // is a batch of vectors (one dimension less than a) or a batch of matrices.
  constexpr size_t matrix_rank = 2;
  const size_t b_core_rank = (b_dim == a_dim - 1) ? 1 : matrix_rank;

  CheckAndConvertUtils::CheckValue<size_t>("dim of matrix a", a_dim, kGreaterEqual, matrix_rank, prim_name);
  CheckAndConvertUtils::CheckValue<size_t>("dim of matrix b", b_dim, kGreaterEqual, b_core_rank, prim_name);

  const bool rank_compatible = (a_dim == b_dim) || (a_dim - 1 == b_dim);
  if (!rank_compatible) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name
                             << "', the dimension of `b` should be 'a.dim' or 'a.dim' - 1, which is " << a_dim << " or "
                             << (a_dim - 1) << ", but got " << b_dim << " dimensions.";
  }

  const auto a_rows = a_shape[a_dim - matrix_rank];
  const auto a_cols = a_shape[a_dim - 1];
  if (a_cols != a_rows) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name
                             << "', the last two dimensions of `a` should be the same, but got shape of " << a_shape
                             << ". Please make sure that the shape of `a` be like [..., N, N].";
  }

  if (a_rows != b_shape[b_dim - b_core_rank]) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name
                             << "', the last two dimensions of `a` and `b` should be matched, but got shape of "
                             << a_shape << " and " << b_shape
                             << ". Please make sure that the shape of `a` and `b` be like [..., N, N] X [..., N, M] or "
                                "[..., N, N] X [..., N].";
  }

  // Everything in front of the trailing core dimensions is a batch dimension and must agree.
  const auto a_batch_end = a_shape.begin() + (a_dim - matrix_rank);
  const auto b_batch_end = b_shape.begin() + (b_dim - b_core_rank);
  const bool same_batch = std::equal(a_shape.begin(), a_batch_end, b_shape.begin(), b_batch_end);
  if (!same_batch) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name
                             << "', the batch dimensions of `a` and `b` should all be the same, but got shape of "
                             << a_shape << " and " << b_shape
                             << ". Please make sure that the shape of `a` and `b` be like [a, b, c, ..., N, N] X [a, "
                                "b, c, ..., N, M] or [a, b, c, ..., N, N] X [a, b, c, ..., N].";
  }

  return b_shape_ptr->cast<abstract::ShapePtr>();
}
|
||||
|
||||
/**
 * Infer the output dtype of SolveTriangular.
 *
 * `a` and `b` must be tensors of one common dtype drawn from float32 and float64; the check itself is
 * delegated to CheckAndConvertUtils::CheckTensorTypeSame.
 *
 * @param primitive  The SolveTriangular primitive, used for its name in error reporting.
 * @param input_args Abstract inputs; index 0 is `a`, index 1 is `b`.
 * @return The type returned by CheckTensorTypeSame.
 */
TypePtr SolveTriangularInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  const auto prim_name = primitive->name();

  std::map<std::string, TypePtr> checked_types;
  (void)checked_types.emplace("a type", input_args[kInputIndex0]->BuildType());
  (void)checked_types.emplace("b type", input_args[kInputIndex1]->BuildType());

  return CheckAndConvertUtils::CheckTensorTypeSame(checked_types, {kFloat32, kFloat64}, prim_name);
}
|
||||
} // namespace
|
||||
void SolveTriangular::Init(bool lower, bool unit_diagonal, std::string trans) { set_unit_diagonal(unit_diagonal); }
|
||||
|
||||
/// Store `unit_diagonal` under the kUnitDiagonal attribute.
void SolveTriangular::set_unit_diagonal(bool unit_diagonal) {
  const auto attr_value = api::MakeValue(unit_diagonal);
  (void)AddAttr(kUnitDiagonal, attr_value);
}
|
||||
|
||||
bool SolveTriangular::get_unit_diagonal() const {
|
||||
auto value_ptr = GetAttr(kUnitDiagonal);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
/// Store `lower` under the kLower attribute.
void SolveTriangular::set_lower(bool lower) {
  const auto attr_value = api::MakeValue(lower);
  (void)AddAttr(kLower, attr_value);
}
|
||||
|
||||
bool SolveTriangular::get_lower() const {
|
||||
auto value_ptr = GetAttr(kLower);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
/// Store `trans` under the kTrans attribute.
void SolveTriangular::set_trans(std::string trans) {
  const auto attr_value = api::MakeValue(trans);
  (void)AddAttr(kTrans, attr_value);
}
|
||||
|
||||
std::string SolveTriangular::get_trans() const {
|
||||
auto value_ptr = GetAttr(kTrans);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SolveTriangular, BaseOperator);
|
||||
|
||||
/**
 * Infer the abstract (shape and dtype) of the SolveTriangular output.
 *
 * @param primitive  The SolveTriangular primitive; must not be null.
 * @param input_args Abstract inputs; exactly two are expected (`a` and `b`).
 * @return The abstract built from the inferred shape and type.
 */
AbstractBasePtr SolveTriangularInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t expected_input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, expected_input_num, primitive->name());
  // The dtype is checked before the shape.
  const auto out_type = SolveTriangularInferType(primitive, input_args);
  const auto out_shape = SolveTriangularInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SolveTriangular, prim::kPrimSolveTriangular, SolveTriangularInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SOLVE_TRIANGULAR_H_
|
||||
#define MINDSPORE_CORE_OPS_SOLVE_TRIANGULAR_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSolveTriangular = "SolveTriangular";
|
||||
/// \brief Assert defined MatrixSolve operator prototype.
|
||||
class MIND_API SolveTriangular : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SolveTriangular);
|
||||
/// \brief Constructor.
|
||||
SolveTriangular() : BaseOperator(kNameSolveTriangular) { InitIOName({"a", "b"}, {"output"}); }
|
||||
/// \brief Init.
|
||||
void Init(bool lower, bool unit_diagonal, std::string trans);
|
||||
|
||||
/// \brief Method to set unit_diagonal attributes.
|
||||
void set_unit_diagonal(bool unit_diagonal);
|
||||
/// \brief Method to get unit_diagonal attributes.
|
||||
bool get_unit_diagonal() const;
|
||||
|
||||
/// \brief Method to set lower attributes.
|
||||
void set_lower(bool lower);
|
||||
/// \brief Method to get lower attributes.
|
||||
bool get_lower() const;
|
||||
|
||||
/// \brief Method to set trans attributes.
|
||||
void set_trans(std::string trans);
|
||||
/// \brief Method to get trans attributes.
|
||||
std::string get_trans() const;
|
||||
};
|
||||
abstract::AbstractBasePtr SolveTriangularInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimSolveTriangularPtr = std::shared_ptr<SolveTriangular>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_SOLVE_TRIANGULAR_H_
|
|
@ -18,7 +18,7 @@ from .._checkparam import Validator as validator
|
|||
from ..common import dtype as mstype
|
||||
|
||||
|
||||
class SolveTriangular(PrimitiveWithInfer):
|
||||
class SolveTriangular(Primitive):
|
||||
"""
|
||||
Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.
|
||||
|
||||
|
@ -87,49 +87,6 @@ class SolveTriangular(PrimitiveWithInfer):
|
|||
|
||||
self.init_prim_io_names(inputs=['a', 'b'], outputs=['output'])
|
||||
|
||||
def __infer__(self, a, b):
|
||||
a_shape = a['shape']
|
||||
b_shape = b['shape']
|
||||
# shape match
|
||||
b_vector = len(b_shape) == len(a_shape) - 1
|
||||
if len(a_shape) < 2:
|
||||
raise ValueError(f"For '{self.name}', the dimension of `a` should be at least 2,"
|
||||
f" but got {len(a_shape)} dimensions.")
|
||||
b_len = 1 if b_vector else 2
|
||||
if len(b_shape) < b_len:
|
||||
raise ValueError(f"For '{self.name}', the dimension of `b` should be at least {b_len},"
|
||||
f" but got {len(b_shape)} dimensions.")
|
||||
if len(a_shape) != len(b_shape) and len(a_shape) - 1 != len(b_shape):
|
||||
raise ValueError(f"For '{self.name}', the dimension of `b` should be 'a.dim' or 'a.dim' - 1, "
|
||||
f"which is {len(a_shape)} or {len(a_shape) - 1}, but got {len(b_shape)} dimensions.")
|
||||
if a_shape[-1] != a_shape[-2]:
|
||||
raise ValueError(f"For '{self.name}', the last two dimensions of `a` should be the same,"
|
||||
f" but got shape of {a_shape}."
|
||||
f" Please make sure that the shape of `a` be like [..., N, N]")
|
||||
if a_shape[-2] != b_shape[-b_len]:
|
||||
raise ValueError(f"For '{self.name}', the last two dimensions of `a` and `b` should be matched,"
|
||||
f" but got shape of {a_shape} and {b_shape}."
|
||||
f" Please make sure that the shape of `a` and `b` be like"
|
||||
f" [..., N, N] X [..., N, M] or [..., N, N] X [..., N].")
|
||||
if a_shape[:-2] != b_shape[:-b_len]:
|
||||
raise ValueError(f"For '{self.name}', the batch dimensions of `a` and `b` should all be the same,"
|
||||
f" but got shape of {a_shape} and {b_shape}."
|
||||
f" Please make sure that the shape of `a` and `b` be like"
|
||||
f" [a, b, c, ..., N, N] X [a, b, c, ..., N, M] or"
|
||||
f" [a, b, c, ..., N, N] X [a, b, c, ..., N].")
|
||||
|
||||
validator.check_scalar_or_tensor_types_same({"a_dtype": a['dtype'], "b_dtype": b['dtype']},
|
||||
[mstype.float32, mstype.float64], self.name)
|
||||
return {
|
||||
'shape': tuple(b_shape),
|
||||
'dtype': a['dtype'],
|
||||
'value': None
|
||||
}
|
||||
|
||||
def infer_dtype(self, a_dtype, b_dtype):
|
||||
del b_dtype
|
||||
return a_dtype
|
||||
|
||||
|
||||
class Eigh(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -20,7 +20,8 @@ import numpy as np
|
|||
import scipy as scp
|
||||
from scipy.linalg import solve_triangular, eig, eigvals
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import Tensor, context, nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.operations.math_ops import Cholesky
|
||||
from mindspore.scipy.ops import Eigh, Eig, SolveTriangular
|
||||
from mindspore.scipy.utils import _nd_transpose
|
||||
|
@ -29,6 +30,15 @@ from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_ma
|
|||
np.random.seed(0)
|
||||
|
||||
|
||||
class SolveTriangularNet(nn.Cell):
    """Cell wrapper around the SolveTriangular primitive, so a test can call `set_inputs` on it."""

    def __init__(self, lower: bool = False, unit_diagonal: bool = False, trans: str = 'N'):
        super().__init__()
        # The three options are forwarded unchanged to the primitive.
        self.solve = SolveTriangular(lower, unit_diagonal, trans)

    def construct(self, a, b):
        # Returns the primitive's result for `a` and `b`.
        return self.solve(a, b)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -317,7 +327,14 @@ def test_solve_triangular_matrix(shape: int, dtype, lower: bool, unit_diagonal:
|
|||
a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
|
||||
b = np.random.random((m, n)).astype(dtype)
|
||||
expect = solve_triangular(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
|
||||
output = SolveTriangular(lower, unit_diagonal, trans)(Tensor(a), Tensor(b)).asnumpy()
|
||||
|
||||
type_convert = {np.float32: mstype.float32, np.float64: mstype.float64}
|
||||
|
||||
dynamic_net = SolveTriangularNet(lower, unit_diagonal, trans)
|
||||
place_holder = Tensor(shape=[None, None], dtype=type_convert.get(dtype))
|
||||
dynamic_net.set_inputs(place_holder, place_holder)
|
||||
|
||||
output = dynamic_net(Tensor(a), Tensor(b)).asnumpy()
|
||||
np.testing.assert_almost_equal(expect, output, decimal=5)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue