cholesky dynamic shape

This commit is contained in:
mengyuanli 2022-10-12 16:21:59 +08:00
parent 01ca5e1555
commit bcc4a9b9a1
5 changed files with 118 additions and 71 deletions

View File

@ -53,36 +53,45 @@ void CholeskyCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size
outer_batch_ /= ((*row) * (*col));
}
void CholeskyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kernel_name_);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kernel_name_);
auto input_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex));
InitMatrixInfo(input_shape, &input_row_, &input_col_);
auto output_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex));
InitMatrixInfo(output_shape, &output_row_, &output_col_);
if (common::AnfAlgo::HasNodeAttr("upper", kernel_node)) {
flag_ = false;
upper_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "upper");
}
// If clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
if (common::AnfAlgo::HasNodeAttr(CLEAN, kernel_node)) {
clean_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, CLEAN);
}
if (common::AnfAlgo::HasNodeAttr(LOWER, kernel_node)) {
lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
bool CholeskyCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->GetPrim()->name();
dtype_ = inputs[kIndex0]->GetDtype();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Cholesky does not support this kernel data type: " << kernel_attr;
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
if (base_operator->HasAttr("upper")) {
flag_ = false;
upper_ = GetValue<bool>(base_operator->GetAttr("upper"));
}
// If clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
if (base_operator->HasAttr(CLEAN)) {
clean_ = GetValue<bool>(base_operator->GetAttr(CLEAN));
}
if (base_operator->HasAttr(LOWER)) {
lower_ = GetValue<bool>(base_operator->GetAttr(LOWER));
}
return true;
}
int CholeskyCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
auto input_shape = LongVecToSizeVec(inputs[kInputIndex]->GetShapeVector());
InitMatrixInfo(input_shape, &input_row_, &input_col_);
auto output_shape = LongVecToSizeVec(inputs[kOutputIndex]->GetShapeVector());
InitMatrixInfo(output_shape, &output_row_, &output_col_);
return KRET_OK;
}
template <typename T>

View File

@ -14,21 +14,27 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_CPU_CHOLESKY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_CPU_CHOLESKY_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class CholeskyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class CholeskyCpuKernelMod : public NativeCpuKernelMod {
public:
CholeskyCpuKernelMod() = default;
~CholeskyCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
@ -61,4 +67,4 @@ class CholeskyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_EIGEN_CPU_CHOLESKY_CPU_KERNEL_H_

View File

@ -14,13 +14,14 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CHOLESKY_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CHOLESKY_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <algorithm>
#include <map>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/eye_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_split_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triangle_matrix_copy_impl.cuh"
@ -39,25 +40,27 @@ constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
template <typename T>
class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class CholeskyGpuKernelMod : public NativeGpuKernelMod {
public:
using pointer = T *;
CholeskyGpuKernelMod() = default;
~CholeskyGpuKernelMod() = default;
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
if (common::AnfAlgo::HasNodeAttr("upper", kernel_node)) {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCholeskyInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCholeskyOutputsNum, kernel_name_);
kernel_name_ = base_operator->GetPrim()->name();
if (base_operator->HasAttr("upper")) {
flag_ = false;
upper_ = static_cast<bool>(GetAttr<bool>(kernel_node, "upper"));
upper_ = GetValue<bool>(base_operator->GetAttr("upper"));
}
// If clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
if (common::AnfAlgo::HasNodeAttr(kClean, kernel_node)) {
clean_ = static_cast<bool>(GetAttr<bool>(kernel_node, kClean));
if (base_operator->HasAttr(kClean)) {
clean_ = GetValue<bool>(base_operator->GetAttr(kClean));
}
if (common::AnfAlgo::HasNodeAttr(kLower, kernel_node)) {
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
if (base_operator->HasAttr(kLower)) {
lower_ = GetValue<bool>(base_operator->GetAttr(kLower));
}
// if clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
// Cholesky input is sys_positive_matrix and saved by col_major in gpu backend.
@ -77,18 +80,23 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
// Get CuSolver Dense: matrix handler
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
return true;
}
auto shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
if (IsDynamic(shape_signed)) {
return true;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
auto in_shape = Convert2SizeTClipNeg(shape_signed);
is_null_input_ = CHECK_SHAPE_NULL(in_shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return true;
input_size_list_.clear();
output_size_list_.clear();
auto in_shape = LongVecToSizeVec(inputs[kInputIndex]->GetShapeVector());
if (!InitNoSplitDim(in_shape)) {
return KRET_RESIZE_FAILED;
}
return InitNoSplitDim(in_shape);
return KRET_OK;
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@ -102,7 +110,8 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
bool InitNoSplitDim(const std::vector<size_t> &shape) {
constexpr size_t min_dim = 1;
if (shape.size() <= min_dim) {
MS_LOG_EXCEPTION << kernel_name_ << " input or output shape dim is " << shape.size() << " which is invalid.";
MS_LOG(ERROR) << kernel_name_ << " input or output shape dim is " << shape.size() << " which is invalid.";
return false;
}
cho_row_ = shape.at(shape.size() - kRowIndex);
cho_col_ = shape.at(shape.size() - kColIndex);
@ -111,9 +120,10 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
outer_batch_ *= shape.at(batch);
}
if (cho_row_ != cho_col_) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. "
<< "Cholesky expects a square matrix. but input or output shape is: " << cho_row_ << ", "
<< cho_col_;
MS_LOG(ERROR) << kernel_name_ << " input shape is invalid. "
<< "Cholesky expects a square matrix. but input or output shape is: " << cho_row_ << ", "
<< cho_col_;
return false;
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(cho_row_);
@ -124,7 +134,7 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
return true;
}
void InitSizeLists() override {
void InitSizeLists() {
size_t workspace_size = outer_batch_ * sizeof(pointer);
workspace_size_list_.emplace_back(workspace_size);
workspace_size = outer_batch_ * sizeof(int);
@ -146,29 +156,29 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
// Copy input data to output, cholesky inplace output in gpu backend.
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(output_addr, input1_addr, outer_batch_ * m_ * lda_ * unit_size_,
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy input to output Fail");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(output_addr, input1_addr, outer_batch_ * m_ * lda_ * unit_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
kernel_name_ + "cuda memcopy input to output Fail");
for (size_t i = 0; i < outer_batch_; i++) {
h_array_[i] = output_addr + i * lda_ * m_;
}
// Copy output's addr to d_array_addr
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * outer_batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * outer_batch_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
kernel_name_ + "cuda memcopy Fail");
// Solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
"cusolver cholesky batched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
"cusolver cholesky batched Fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
@ -189,7 +199,6 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
size_t lda_{0};
size_t ldb_{0};
int res_dim_{0};
bool is_null_input_{false};
cusolverDnHandle_t handle_{nullptr};
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
std::vector<pointer> h_array_;
@ -201,4 +210,4 @@ class CholeskyGpuKernelMod : public DeprecatedNativeGpuKernelMod {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CHOLESKY_GPU_KERNEL_H_

View File

@ -33,6 +33,10 @@ abstract::ShapePtr CholeskyInferShape(const PrimitivePtr &primitive, const std::
auto x_shape = input_x->shape();
MS_EXCEPTION_IF_NULL(x_shape);
auto const &x_shape_list = x_shape->shape();
// support dynamic rank
if (IsDynamicRank(x_shape_list)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
}
const size_t x_dim = x_shape_list.size();
constexpr size_t kDefaultRank = 2;
constexpr size_t kRowIndex = 2;
@ -42,6 +46,10 @@ abstract::ShapePtr CholeskyInferShape(const PrimitivePtr &primitive, const std::
<< "equal to 2"
<< ", but got a " << x_dim << "-D Tensor.";
}
// support dynamic shape
if (IsDynamic(x_shape_list)) {
return input_x->BuildShape()->cast<abstract::ShapePtr>();
}
if (x_shape_list[x_dim - kColIndex] != x_shape_list[x_dim - kRowIndex]) {
MS_EXCEPTION(ValueError) << "For Cholesky, input x must be batch squares"
<< ", but got batch " << x_shape_list[x_dim - kRowIndex] << " x "
@ -58,6 +66,12 @@ TypePtr CholeskyInferType(const PrimitivePtr &primitive, const std::vector<Abstr
}
} // namespace
void Cholesky::set_upper(const bool upper) { (void)this->AddAttr("upper", api::MakeValue(upper)); }
bool Cholesky::get_upper() const { return GetValue<bool>(GetAttr("upper")); }
void Cholesky::Init(const bool upper) { set_upper(upper); }
MIND_API_OPERATOR_IMPL(Cholesky, BaseOperator);
AbstractBasePtr CholeskyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {

View File

@ -32,6 +32,15 @@ class MIND_API Cholesky : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Cholesky);
Cholesky() : BaseOperator(kNameCholesky) { InitIOName({"input_x"}, {"output"}); }
void Init(const bool upper = false);
/// \brief Set upper.
void set_upper(const bool upper);
/// \brief Get upper.
///
/// \return upper.
bool get_upper() const;
};
abstract::AbstractBasePtr CholeskyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);