forked from mindspore-Ecosystem/mindspore
!8595 Add Cholesky op at GPU back-end
From: @peixu_ren Reviewed-by: @zichun_ye,@wilfchen Signed-off-by:
This commit is contained in:
commit
eb7cbd49ac
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "triangle_matrix_copy_impl.cuh"
|
||||||
|
template <typename T>
|
||||||
|
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo,
|
||||||
|
const size_t count, const size_t ldb, const size_t m) {
|
||||||
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||||
|
size_t batchIdx = i / (ldb * m);
|
||||||
|
size_t row = (i - batchIdx * ldb * m) / m;
|
||||||
|
size_t col = (i - batchIdx * ldb * m) % m;
|
||||||
|
|
||||||
|
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the upper half of the matrix should be all 0;
|
||||||
|
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the lower half of the matrix should be all 0;
|
||||||
|
if (uplo == CUBLAS_FILL_MODE_UPPER) {
|
||||||
|
if (col > row) {
|
||||||
|
output[i] = 0;
|
||||||
|
} else {
|
||||||
|
output[i] = input[i];
|
||||||
|
}
|
||||||
|
} else if (uplo == CUBLAS_FILL_MODE_LOWER) {
|
||||||
|
if (col < row) {
|
||||||
|
output[i] = 0;
|
||||||
|
} else {
|
||||||
|
output[i] = input[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo,
|
||||||
|
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream) {
|
||||||
|
TriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template void TriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo, const size_t count,
|
||||||
|
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||||
|
template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
|
||||||
|
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,24 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_
|
||||||
|
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_
|
||||||
|
|
||||||
|
#include "runtime/device/gpu/cuda_common.h"
|
||||||
|
template <typename T>
|
||||||
|
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo,
|
||||||
|
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream);
|
||||||
|
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_
|
|
@ -0,0 +1,23 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
CholeskyGpuKernel, float)
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,247 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
|
||||||
|
#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh"
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
|
||||||
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||||
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||||
|
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||||
|
#include "utils/convert_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
template <typename T>
|
||||||
|
class CholeskyGpuKernel : public GpuKernel {
|
||||||
|
public:
|
||||||
|
CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {}
|
||||||
|
~CholeskyGpuKernel() = default;
|
||||||
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||||
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||||
|
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
if (is_null_input_) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
|
||||||
|
auto output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||||
|
auto d_array_addr = GetDeviceAddress<T *>(workspace, 0);
|
||||||
|
auto d_identity_addr = GetDeviceAddress<T *>(workspace, 1);
|
||||||
|
if (!use_split_matrix) {
|
||||||
|
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
|
||||||
|
for (size_t i = 0; i < batch_; i++) {
|
||||||
|
h_array[i] = input1_addr + i * lda_ * m_;
|
||||||
|
h_identity[i] = output_addr + i * ldb_ * m_;
|
||||||
|
CHECK_CUDA_RET_WITH_ERROR(
|
||||||
|
cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_,
|
||||||
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cuda memcopy Fail");
|
||||||
|
}
|
||||||
|
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_,
|
||||||
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cuda memcopy Fail");
|
||||||
|
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_,
|
||||||
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cuda memcopy Fail");
|
||||||
|
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||||
|
cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||||
|
"cusolver cholesky batched Fail");
|
||||||
|
TriangleMatrixCopy(input1_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_,
|
||||||
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
} else {
|
||||||
|
auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2);
|
||||||
|
auto d_batch_input_addr = GetDeviceAddress<T>(workspace, 3);
|
||||||
|
for (size_t i = 0; i < batch_; i++) {
|
||||||
|
h_array[i] = d_batch_input_addr + i * lda_ * m_;
|
||||||
|
h_identity[i] = output_addr + i * ldb_ * m_;
|
||||||
|
}
|
||||||
|
Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr,
|
||||||
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_,
|
||||||
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cuda memcopy Fail");
|
||||||
|
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_,
|
||||||
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cuda memcopy Fail");
|
||||||
|
CHECK_CUSOLVER_RET_WITH_EXCEPT(
|
||||||
|
cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_),
|
||||||
|
"cusolver cholesky batched Fail");
|
||||||
|
TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_,
|
||||||
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
|
||||||
|
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
|
||||||
|
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||||
|
split_dim = static_cast<int>(GetAttr<int64_t>(kernel_node, "split_dim"));
|
||||||
|
if (split_dim == 0) {
|
||||||
|
InitNoSpltDim(in_shape);
|
||||||
|
} else {
|
||||||
|
InitSpltDim(in_shape);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InitNoSpltDim(const std::vector<size_t> &in_shape) {
|
||||||
|
use_split_matrix = false;
|
||||||
|
if (in_shape.size() == 2) {
|
||||||
|
batch_ = 1;
|
||||||
|
if (in_shape[0] != in_shape[1]) {
|
||||||
|
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||||
|
}
|
||||||
|
} else if (in_shape.size() == 3) {
|
||||||
|
batch_ = SizeToInt(in_shape[0]);
|
||||||
|
if (in_shape[1] != in_shape[2]) {
|
||||||
|
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
|
||||||
|
}
|
||||||
|
|
||||||
|
m_ = SizeToInt(in_shape[1]);
|
||||||
|
lda_ = m_;
|
||||||
|
ldb_ = m_;
|
||||||
|
h_array.resize(batch_);
|
||||||
|
h_identity.resize(batch_);
|
||||||
|
h_identity_data.resize(m_ * m_);
|
||||||
|
for (size_t i = 0; i < m_; i++) {
|
||||||
|
for (size_t j = 0; j < m_; j++) {
|
||||||
|
if (i == j) {
|
||||||
|
h_identity_data[i * m_ + j] = 1;
|
||||||
|
} else {
|
||||||
|
h_identity_data[i * m_ + j] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
InitSizeLists();
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitSpltDim(const std::vector<size_t> &in_shape) {
|
||||||
|
if (in_shape.size() != 2) {
|
||||||
|
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
|
||||||
|
}
|
||||||
|
height = in_shape[0];
|
||||||
|
width = in_shape[1];
|
||||||
|
if (height != width) {
|
||||||
|
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
|
||||||
|
}
|
||||||
|
if (SizeToInt(height) <= split_dim) {
|
||||||
|
use_split_matrix = false;
|
||||||
|
batch_ = 1;
|
||||||
|
m_ = SizeToInt(in_shape[1]);
|
||||||
|
lda_ = m_;
|
||||||
|
ldb_ = m_;
|
||||||
|
h_array.resize(batch_);
|
||||||
|
h_identity.resize(batch_);
|
||||||
|
h_identity_data.resize(m_ * m_);
|
||||||
|
for (size_t i = 0; i < m_; i++) {
|
||||||
|
for (size_t j = 0; j < m_; j++) {
|
||||||
|
if (i == j) {
|
||||||
|
h_identity_data[i * m_ + j] = 1;
|
||||||
|
} else {
|
||||||
|
h_identity_data[i * m_ + j] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
InitSizeLists();
|
||||||
|
} else {
|
||||||
|
use_split_matrix = true;
|
||||||
|
int batch = SizeToInt(in_shape[1]) / split_dim;
|
||||||
|
res_dim = in_shape[1] - batch * split_dim;
|
||||||
|
if (res_dim == 0) {
|
||||||
|
batch_ = batch;
|
||||||
|
} else {
|
||||||
|
batch_ = batch + 1;
|
||||||
|
}
|
||||||
|
m_ = split_dim;
|
||||||
|
lda_ = m_;
|
||||||
|
ldb_ = m_;
|
||||||
|
h_array.resize(batch_);
|
||||||
|
h_identity.resize(batch_);
|
||||||
|
h_identity_data.resize(m_ * m_);
|
||||||
|
for (size_t i = 0; i < m_; i++) {
|
||||||
|
for (size_t j = 0; j < m_; j++) {
|
||||||
|
if (i == j) {
|
||||||
|
h_identity_data[i * m_ + j] = 1;
|
||||||
|
} else {
|
||||||
|
h_identity_data[i * m_ + j] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
InitSizeLists();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitSizeLists() override {
|
||||||
|
size_t unit_size = sizeof(T);
|
||||||
|
size_t input_size;
|
||||||
|
size_t workspace_size;
|
||||||
|
if (!use_split_matrix) {
|
||||||
|
input_size = batch_ * m_ * lda_ * unit_size;
|
||||||
|
} else {
|
||||||
|
input_size = height * width * unit_size;
|
||||||
|
workspace_size = batch_ * m_ * lda_ * unit_size;
|
||||||
|
workspace_size_list_.push_back(workspace_size);
|
||||||
|
}
|
||||||
|
input_size_list_.push_back(input_size);
|
||||||
|
size_t output_size = batch_ * m_ * lda_ * unit_size;
|
||||||
|
output_size_list_.push_back(output_size);
|
||||||
|
workspace_size = batch_ * sizeof(T *);
|
||||||
|
workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
|
||||||
|
workspace_size = batch_ * sizeof(T *);
|
||||||
|
workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
|
||||||
|
workspace_size = batch_ * sizeof(int);
|
||||||
|
workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t batch_;
|
||||||
|
size_t m_;
|
||||||
|
size_t lda_;
|
||||||
|
size_t ldb_;
|
||||||
|
int res_dim;
|
||||||
|
int split_dim;
|
||||||
|
bool is_null_input_;
|
||||||
|
bool use_split_matrix;
|
||||||
|
size_t height;
|
||||||
|
size_t width;
|
||||||
|
cusolverDnHandle_t handle_;
|
||||||
|
cublasHandle_t blas_handle_;
|
||||||
|
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
|
||||||
|
std::vector<T *> h_array;
|
||||||
|
std::vector<T *> h_identity;
|
||||||
|
std::vector<T> h_identity_data;
|
||||||
|
std::vector<size_t> input_size_list_;
|
||||||
|
std::vector<size_t> output_size_list_;
|
||||||
|
std::vector<size_t> workspace_size_list_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif
|
|
@ -88,7 +88,7 @@ from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingB
|
||||||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||||
CusMatMulCubeDenseRight,
|
CusMatMulCubeDenseRight,
|
||||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, CholeskyTrsm, DetTriangle)
|
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle)
|
||||||
from .sparse_ops import SparseToDense
|
from .sparse_ops import SparseToDense
|
||||||
from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx
|
from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx
|
||||||
|
|
||||||
|
|
|
@ -608,9 +608,42 @@ class UpdateThorGradient(PrimitiveWithInfer):
|
||||||
return x2_dtype
|
return x2_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class Cholesky(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Inner API for positive-definite matrix Cholesky decomposition GPU backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, split_dim=0):
|
||||||
|
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
|
||||||
|
self.split_dim = split_dim
|
||||||
|
self.add_prim_attr('split_dim', self.split_dim)
|
||||||
|
|
||||||
|
def infer_shape(self, x1_shape):
|
||||||
|
if self.split_dim != 0:
|
||||||
|
assert len(x1_shape) == 2
|
||||||
|
height = x1_shape[0]
|
||||||
|
width = x1_shape[1]
|
||||||
|
assert height == width
|
||||||
|
if height <= self.split_dim:
|
||||||
|
out_shape = [1, height, width]
|
||||||
|
else:
|
||||||
|
batch = height // self.split_dim
|
||||||
|
if height != batch * self.split_dim:
|
||||||
|
batch += 1
|
||||||
|
out_shape = [batch, self.split_dim, self.split_dim]
|
||||||
|
else:
|
||||||
|
out_shape = x1_shape
|
||||||
|
return out_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x1_dtype):
|
||||||
|
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
|
||||||
|
return x1_dtype
|
||||||
|
|
||||||
|
|
||||||
class CholeskyTrsm(PrimitiveWithInfer):
|
class CholeskyTrsm(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Inner API for resnet50 THOR GPU backend
|
Inner API for resnet50 THOR GPU backend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
|
@ -643,7 +676,23 @@ class CholeskyTrsm(PrimitiveWithInfer):
|
||||||
|
|
||||||
class DetTriangle(PrimitiveWithInfer):
|
class DetTriangle(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Calculate the determinant of triangle matrices
|
Calculate the determinant of triangle matrices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fill_mode (tuple): The target shape to broadcast.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, with the given `shape` and the same data type as `input_x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> shape = (2, 3)
|
||||||
|
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
|
||||||
|
>>> broadcast_to = P.BroadcastTo(shape)
|
||||||
|
>>> broadcast_to(input_x)
|
||||||
|
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
class NetCholesky(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetCholesky, self).__init__()
|
||||||
|
self.cholesky = P.Cholesky()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.cholesky(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cholesky_fp32():
|
||||||
|
cholesky = NetCholesky()
|
||||||
|
x = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
|
||||||
|
output = cholesky(Tensor(x, dtype=mstype.float32))
|
||||||
|
expect = np.linalg.cholesky(x)
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(output.asnumpy() - expect) < tol).all()
|
Loading…
Reference in New Issue