forked from mindspore-Ecosystem/mindspore
cholesky dynamic shape
This commit is contained in:
parent
01ca5e1555
commit
bcc4a9b9a1
|
@ -53,36 +53,45 @@ void CholeskyCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size
|
|||
outer_batch_ /= ((*row) * (*col));
|
||||
}
|
||||
|
||||
void CholeskyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kernel_name_);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kernel_name_);
|
||||
auto input_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex));
|
||||
InitMatrixInfo(input_shape, &input_row_, &input_col_);
|
||||
auto output_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex));
|
||||
InitMatrixInfo(output_shape, &output_row_, &output_col_);
|
||||
if (common::AnfAlgo::HasNodeAttr("upper", kernel_node)) {
|
||||
flag_ = false;
|
||||
upper_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "upper");
|
||||
}
|
||||
// If clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
|
||||
if (common::AnfAlgo::HasNodeAttr(CLEAN, kernel_node)) {
|
||||
clean_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, CLEAN);
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(LOWER, kernel_node)) {
|
||||
lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
bool CholeskyCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
dtype_ = inputs[kIndex0]->GetDtype();
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Cholesky does not support this kernel data type: " << kernel_attr;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
if (base_operator->HasAttr("upper")) {
|
||||
flag_ = false;
|
||||
upper_ = GetValue<bool>(base_operator->GetAttr("upper"));
|
||||
}
|
||||
// If clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
|
||||
if (base_operator->HasAttr(CLEAN)) {
|
||||
clean_ = GetValue<bool>(base_operator->GetAttr(CLEAN));
|
||||
}
|
||||
if (base_operator->HasAttr(LOWER)) {
|
||||
lower_ = GetValue<bool>(base_operator->GetAttr(LOWER));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int CholeskyCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto input_shape = LongVecToSizeVec(inputs[kInputIndex]->GetShapeVector());
|
||||
InitMatrixInfo(input_shape, &input_row_, &input_col_);
|
||||
auto output_shape = LongVecToSizeVec(inputs[kOutputIndex]->GetShapeVector());
|
||||
InitMatrixInfo(output_shape, &output_row_, &output_col_);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -14,21 +14,27 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class CholeskyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class CholeskyCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CholeskyCpuKernelMod() = default;
|
||||
~CholeskyCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
|
@ -61,4 +67,4 @@ class CholeskyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
|
|
|
@ -14,13 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/eye_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_split_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triangle_matrix_copy_impl.cuh"
|
||||
|
@ -39,25 +40,27 @@ constexpr size_t kRowIndex = 2;
|
|||
constexpr size_t kColIndex = 1;
|
||||
|
||||
template <typename T>
|
||||
class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class CholeskyGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
using pointer = T *;
|
||||
CholeskyGpuKernelMod() = default;
|
||||
~CholeskyGpuKernelMod() = default;
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
if (common::AnfAlgo::HasNodeAttr("upper", kernel_node)) {
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCholeskyInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCholeskyOutputsNum, kernel_name_);
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
if (base_operator->HasAttr("upper")) {
|
||||
flag_ = false;
|
||||
upper_ = static_cast<bool>(GetAttr<bool>(kernel_node, "upper"));
|
||||
upper_ = GetValue<bool>(base_operator->GetAttr("upper"));
|
||||
}
|
||||
// If clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
|
||||
if (common::AnfAlgo::HasNodeAttr(kClean, kernel_node)) {
|
||||
clean_ = static_cast<bool>(GetAttr<bool>(kernel_node, kClean));
|
||||
if (base_operator->HasAttr(kClean)) {
|
||||
clean_ = GetValue<bool>(base_operator->GetAttr(kClean));
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kLower, kernel_node)) {
|
||||
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
|
||||
if (base_operator->HasAttr(kLower)) {
|
||||
lower_ = GetValue<bool>(base_operator->GetAttr(kLower));
|
||||
}
|
||||
// if clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
|
||||
// Cholesky input is sys_positive_matrix and saved by col_major in gpu backend.
|
||||
|
@ -77,18 +80,23 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
}
|
||||
// Get CuSolver Dense: matrix handler
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
|
||||
return true;
|
||||
}
|
||||
|
||||
auto shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
|
||||
if (IsDynamic(shape_signed)) {
|
||||
return true;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto in_shape = Convert2SizeTClipNeg(shape_signed);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(in_shape, kernel_name_, "input");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
|
||||
auto in_shape = LongVecToSizeVec(inputs[kInputIndex]->GetShapeVector());
|
||||
if (!InitNoSplitDim(in_shape)) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
return InitNoSplitDim(in_shape);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
|
@ -102,7 +110,8 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
bool InitNoSplitDim(const std::vector<size_t> &shape) {
|
||||
constexpr size_t min_dim = 1;
|
||||
if (shape.size() <= min_dim) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " input or output shape dim is " << shape.size() << " which is invalid.";
|
||||
MS_LOG(ERROR) << kernel_name_ << " input or output shape dim is " << shape.size() << " which is invalid.";
|
||||
return false;
|
||||
}
|
||||
cho_row_ = shape.at(shape.size() - kRowIndex);
|
||||
cho_col_ = shape.at(shape.size() - kColIndex);
|
||||
|
@ -111,9 +120,10 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
outer_batch_ *= shape.at(batch);
|
||||
}
|
||||
if (cho_row_ != cho_col_) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. "
|
||||
<< "Cholesky expects a square matrix. but input or output shape is: " << cho_row_ << ", "
|
||||
<< cho_col_;
|
||||
MS_LOG(ERROR) << kernel_name_ << " input shape is invalid. "
|
||||
<< "Cholesky expects a square matrix. but input or output shape is: " << cho_row_ << ", "
|
||||
<< cho_col_;
|
||||
return false;
|
||||
}
|
||||
// set matrix row or col to be lead dimension
|
||||
m_ = SizeToInt(cho_row_);
|
||||
|
@ -124,7 +134,7 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
return true;
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
void InitSizeLists() {
|
||||
size_t workspace_size = outer_batch_ * sizeof(pointer);
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
workspace_size = outer_batch_ * sizeof(int);
|
||||
|
@ -146,29 +156,29 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
|
||||
|
||||
// Copy input data to output, cholesky inplace output in gpu backend.
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(output_addr, input1_addr, outer_batch_ * m_ * lda_ * unit_size_,
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy input to output Fail");
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
|
||||
cudaMemcpyAsync(output_addr, input1_addr, outer_batch_ * m_ * lda_ * unit_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
kernel_name_ + "cuda memcopy input to output Fail");
|
||||
|
||||
for (size_t i = 0; i < outer_batch_; i++) {
|
||||
h_array_[i] = output_addr + i * lda_ * m_;
|
||||
}
|
||||
|
||||
// Copy output's addr to d_array_addr
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * outer_batch_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cuda memcopy Fail");
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
|
||||
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * outer_batch_, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
kernel_name_ + "cuda memcopy Fail");
|
||||
|
||||
// Solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
|
||||
cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
|
||||
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
|
||||
cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
|
||||
"cusolver cholesky batched Fail");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
|
||||
|
@ -189,7 +199,6 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
size_t lda_{0};
|
||||
size_t ldb_{0};
|
||||
int res_dim_{0};
|
||||
bool is_null_input_{false};
|
||||
cusolverDnHandle_t handle_{nullptr};
|
||||
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
|
||||
std::vector<pointer> h_array_;
|
||||
|
@ -201,4 +210,4 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CHOLESKY_GPU_KERNEL_H_
|
||||
|
|
|
@ -33,6 +33,10 @@ abstract::ShapePtr CholeskyInferShape(const PrimitivePtr &primitive, const std::
|
|||
auto x_shape = input_x->shape();
|
||||
MS_EXCEPTION_IF_NULL(x_shape);
|
||||
auto const &x_shape_list = x_shape->shape();
|
||||
// support dynamic rank
|
||||
if (IsDynamicRank(x_shape_list)) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
const size_t x_dim = x_shape_list.size();
|
||||
constexpr size_t kDefaultRank = 2;
|
||||
constexpr size_t kRowIndex = 2;
|
||||
|
@ -42,6 +46,10 @@ abstract::ShapePtr CholeskyInferShape(const PrimitivePtr &primitive, const std::
|
|||
<< "equal to 2"
|
||||
<< ", but got a " << x_dim << "-D Tensor.";
|
||||
}
|
||||
// support dynamic shape
|
||||
if (IsDynamic(x_shape_list)) {
|
||||
return input_x->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
if (x_shape_list[x_dim - kColIndex] != x_shape_list[x_dim - kRowIndex]) {
|
||||
MS_EXCEPTION(ValueError) << "For Cholesky, input x must be batch squares"
|
||||
<< ", but got batch " << x_shape_list[x_dim - kRowIndex] << " x "
|
||||
|
@ -58,6 +66,12 @@ TypePtr CholeskyInferType(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void Cholesky::set_upper(const bool upper) { (void)this->AddAttr("upper", api::MakeValue(upper)); }
|
||||
|
||||
bool Cholesky::get_upper() const { return GetValue<bool>(GetAttr("upper")); }
|
||||
|
||||
void Cholesky::Init(const bool upper) { set_upper(upper); }
|
||||
|
||||
MIND_API_OPERATOR_IMPL(Cholesky, BaseOperator);
|
||||
AbstractBasePtr CholeskyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -32,6 +32,15 @@ class MIND_API Cholesky : public BaseOperator {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(Cholesky);
|
||||
Cholesky() : BaseOperator(kNameCholesky) { InitIOName({"input_x"}, {"output"}); }
|
||||
|
||||
void Init(const bool upper = false);
|
||||
/// \brief Set upper.
|
||||
void set_upper(const bool upper);
|
||||
|
||||
/// \brief Get upper.
|
||||
///
|
||||
/// \return upper.
|
||||
bool get_upper() const;
|
||||
};
|
||||
abstract::AbstractBasePtr CholeskyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
|
Loading…
Reference in New Issue