!29502 add matrix_set_diag of gpu backend

Merge pull request !29502 from zhuzhongrui/pub_master3
This commit is contained in:
i-robot 2022-01-26 01:53:42 +00:00 committed by Gitee
commit 6c1ccb092e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 369 additions and 15 deletions

View File

@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/matrix_set_diag_gpu_kernel.h"
#include <algorithm>
#include <tuple>
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
// GPU kernel registrations for MatrixSetDiag.
// Signature: (input, diagonal, k) -> output, where input/diagonal/output share
// one dtype and k (the diagonal index pair) is always int32.
// int32 variant.
MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MatrixSetDiagGpuKernelMod, int)
// int64 variant (k stays int32).
MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
MatrixSetDiagGpuKernelMod, int64_t)
// float32 variant.
MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
MatrixSetDiagGpuKernelMod, float)
// float64 variant.
MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
MatrixSetDiagGpuKernelMod, double)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,184 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_SET_DIAG_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_SET_DIAG_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include <utility>
#include "backend/kernel_compiler/gpu//gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_set_diag_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class MatrixSetDiagGpuKernelMod : public NativeGpuKernelMod {
 public:
  MatrixSetDiagGpuKernelMod() = default;
  ~MatrixSetDiagGpuKernelMod() override = default;

  // Validates the MatrixSetDiag node (3 inputs: input matrix, diagonal, k;
  // 1 output) and caches the batch / matrix geometry consumed by Launch.
  // Throws (MS_LOG(EXCEPTION)) on any malformed shape or attribute.
  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
    constexpr size_t required_input_nums = 3;
    constexpr size_t required_output_nums = 1;
    if (AnfAlgo::GetInputNum(kernel_node) != required_input_nums ||
        AnfAlgo::GetOutputTensorNum(kernel_node) != required_output_nums) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the input nums are required to be 3 [input, diagonal, k] and the output nums is "
                           "required to be 1.";
    }
    // An invalid alignment attribute makes GetAlignments throw.
    auto alignment = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAlignment);
    alignment_ = GetAlignments(alignment);
    constexpr int input_index = 0;
    constexpr int diag_index = 1;
    constexpr int diag_k_index = 2;
    constexpr int output_index = 0;
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index);
    auto diag_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_index);
    auto diag_k_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_k_index);
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_index);
    is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input_shape") ||
                     CHECK_SHAPE_NULL(diag_shape, kernel_name_, "diag_shape") ||
                     CHECK_SHAPE_NULL(diag_k_shape, kernel_name_, "diag_k_shape") ||
                     CHECK_SHAPE_NULL(output_shape, kernel_name_, "output_shape");
    if (is_null_input_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the input shape contains empty, which is invalid, please check its input.";
    }
    // input must be at least a (batched) matrix, diag at least a vector, and
    // the output is a same-shaped copy of input with diagonals overwritten.
    constexpr int temporary_2d_dim = 2;
    constexpr int temporary_1d_dim = 1;
    if (SizeToInt(input_shape.size()) < temporary_2d_dim || SizeToInt(diag_shape.size()) < temporary_1d_dim ||
        input_shape != output_shape) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the dimension of input is invalid for input shape greater than 2D, diag shape "
                           "greater than 1D, input shape should equal to output shape.";
    }
    if (SizeToInt(diag_k_shape.size()) != temporary_1d_dim) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the dimension of diag_region's dim should be limited to range (k[0],k[1]).";
    }
    int input_rank = SizeToInt(input_shape.size());
    // Collapse every leading dimension into a single batch dimension.
    for (int i = 0; i < input_rank - temporary_2d_dim; ++i) {
      outer_batch_ *= SizeToInt(input_shape.at(i));
    }
    inner_rows_ = SizeToInt(input_shape.at(input_rank - temporary_2d_dim));
    inner_cols_ = SizeToInt(input_shape.at(input_rank - temporary_1d_dim));
    // When diag has the same rank as input, its second-to-last dimension is the
    // number of diagonals supplied per batch; otherwise one diagonal is expected.
    expected_num_diags_ =
      SizeToInt(diag_shape.size()) == input_rank ? SizeToInt(diag_shape.at(input_rank - temporary_2d_dim)) : 1;
    for (const auto &diag_sh : diag_shape) {
      diagonal_count_ *= diag_sh;
    }
    for (const auto &k_sh : diag_k_shape) {
      k_count_ *= k_sh;
    }
    InitSizeLists();
    return true;
  }

  // Copies input to output, then overwrites the diagonal band [k[0], k[1]]
  // with values from diag. The k tensor is fetched to the host because the
  // band bounds drive host-side validation and kernel launch parameters.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto input = inputs.at(kDim0);
    auto diag = inputs.at(kDim1);
    constexpr int diag_k_index = 2;
    auto k = inputs.at(diag_k_index);
    auto output = outputs.at(kDim0);
    T *input_addr = reinterpret_cast<T *>(input->addr);
    T *diag_addr = reinterpret_cast<T *>(diag->addr);
    int *diag_k_addr = reinterpret_cast<int *>(k->addr);
    T *output_addr = reinterpret_cast<T *>(output->addr);
    // host_k_vec holds at most (lower, upper); clamp the copy size so an
    // oversized k tensor cannot overflow the host buffer.
    std::vector<int> host_k_vec(diag_k_index, 0);
    size_t copy_size = std::min(k->size, host_k_vec.size() * sizeof(int));
    CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                              cudaMemcpyAsync(host_k_vec.data(), diag_k_addr, copy_size, cudaMemcpyDeviceToHost,
                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                              "matrix_set_diag cuda memcopy device to host Fail");
    // The async D2H copy must complete before host_k_vec is dereferenced below.
    CHECK_CUDA_RET_WITH_ERROR(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
                              "matrix_set_diag cuda stream synchronize Fail");
    lower_diag_index_ = host_k_vec.at(kDim0);
    // A single-element k denotes one diagonal: upper == lower.
    upper_diag_index_ = (k_count_ == 1) ? lower_diag_index_ : host_k_vec.at(kDim1);
    is_single_diag_ = (lower_diag_index_ == upper_diag_index_);
    if (lower_diag_index_ <= -inner_rows_ || lower_diag_index_ >= inner_cols_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the dimension of diag_region's lower_diag_index is invalid, which must be between "
                        << -inner_rows_ << " and " << inner_cols_;
    }
    if (upper_diag_index_ <= -inner_rows_ || upper_diag_index_ >= inner_cols_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the dimension of diag_region's upper_diag_index is invalid, which must be between "
                        << -inner_rows_ << " and " << inner_cols_;
    }
    if (lower_diag_index_ > upper_diag_index_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', diag_region's lower_diag_index must be no greater than upper_diag_index, but got "
                        << lower_diag_index_ << " > " << upper_diag_index_;
    }
    num_diags_ = upper_diag_index_ - lower_diag_index_ + 1;
    if (lower_diag_index_ != upper_diag_index_ && expected_num_diags_ != num_diags_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the dimension of diag_region's lower_diag_index and upper_diag_index are not consistent "
                           "with input shape.";
    }
    // Longest diagonal inside the band; shorter diagonals are padded per alignment_.
    max_diag_len_ =
      std::min(inner_rows_ + std::min(upper_diag_index_, 0), inner_cols_ - std::max(lower_diag_index_, 0));
    // copy input to output first, then set diagonal value to output.
    CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                              cudaMemcpyAsync(output_addr, input_addr, input->size, cudaMemcpyDeviceToDevice,
                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                              "matrix_set_diag cuda memcopy input to output Fail");
    bool right_align_super_diagonal = (alignment_.first == MatrixDiag::RIGHT);
    bool right_align_sub_diagonal = (alignment_.second == MatrixDiag::RIGHT);
    MatrixSetDiag(outer_batch_, inner_rows_, inner_cols_, num_diags_, max_diag_len_, lower_diag_index_,
                  upper_diag_index_, right_align_super_diagonal, right_align_sub_diagonal, is_single_diag_, diag_addr,
                  output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 private:
  // Declares device buffer sizes for the framework allocator: three inputs
  // (input matrix, packed diagonals, int32 k) and one output.
  void InitSizeLists() override {
    input_size_list_.emplace_back(outer_batch_ * inner_rows_ * inner_cols_ * sizeof(T));
    input_size_list_.emplace_back(diagonal_count_ * sizeof(T));
    input_size_list_.emplace_back(k_count_ * sizeof(int));
    output_size_list_.emplace_back(outer_batch_ * inner_rows_ * inner_cols_ * sizeof(T));
  }

  int lower_diag_index_{0};   // k[0]: lowest diagonal written (negative = sub-diagonal).
  int upper_diag_index_{0};   // k[1]: highest diagonal written.
  int inner_rows_{0};         // rows of each innermost matrix.
  int inner_cols_{0};         // cols of each innermost matrix.
  int num_diags_{0};          // upper - lower + 1.
  int expected_num_diags_{0}; // diagonals implied by the diag input's shape.
  int max_diag_len_{0};       // length of the longest diagonal in the band.
  int outer_batch_{1};        // product of all leading (batch) dimensions.
  size_t diagonal_count_{1};  // total element count of the diag input.
  size_t k_count_{1};         // total element count of the k input (1 or 2).
  bool is_single_diag_{true};
  bool is_null_input_{true};
  // <super_matrix_diag_align, sub_matrix_diag_align>
  std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_SET_DIAG_KERNEL_H_

View File

@ -0,0 +1,92 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "matrix_set_diag_impl.cuh"
#include <stdint.h>
#include <cuda_runtime.h>
#include <algorithm>
// Extra start offset of diagonal d inside its padded row of length
// max_diag_len: 0 when the diagonal is left-aligned, (max_diag_len - len)
// when it is right-aligned. d == 0 counts as both super- and sub-diagonal.
__inline__ __device__ int CalDiagOffset(int d, int max_diag_len, const int inner_row, const int inner_col,
                                        const bool right_align_super_diagonal, const bool right_align_sub_diagonal) {
  bool align_right = false;
  if (d >= 0 && right_align_super_diagonal) {
    align_right = true;
  }
  if (d <= 0 && right_align_sub_diagonal) {
    align_right = true;
  }
  // Actual length of diagonal d in an inner_row x inner_col matrix.
  const int len = std::min(inner_row + std::min(0, d), inner_col - std::max(0, d));
  return align_right ? (max_diag_len - len) : 0;
}
// Overwrites the diagonal band [lower_index, upper_index] of the batched
// matrices already stored in output_addr with values taken from diag_addr.
// One thread handles one output element via a grid-stride loop; elements off
// the band are left untouched.
template <typename T>
__global__ void MatrixSetDiagKernel(const int outer_batch, const int inner_row, const int inner_col,
                                    const int num_diags, const int max_diag_len, const int lower_index,
                                    const int upper_index, const bool right_align_super_diagonal,
                                    const bool right_align_sub_diagonal, const bool is_single_diag, const T *diag_addr,
                                    T *output_addr) {
  int count = outer_batch * inner_row * inner_col;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    // Decompose the flat index into (batch, row, col).
    int batch = i / (inner_row * inner_col);
    int row = (i - batch * inner_row * inner_col) / inner_col;
    int col = (i - batch * inner_row * inner_col) % inner_col;
    // d identifies the diagonal this element lies on: 0 = main, >0 super, <0 sub.
    int d = static_cast<int>(col - row);
    if (is_single_diag) {
      // diag is packed as [batch, max_diag_len]; position along the diagonal is
      // col for super-diagonals (d >= 0) and col shifted for d taken as 0/sub.
      if (d == upper_index) {
        output_addr[i] = diag_addr[batch * max_diag_len + col - std::max(upper_index, 0)];
      }
    } else {
      // diag is packed as [batch, num_diags, max_diag_len]; rows are ordered
      // from upper_index down to lower_index.
      int diag_index = upper_index - d;
      // Alignment padding offset for diagonals shorter than max_diag_len.
      int offset =
        CalDiagOffset(d, max_diag_len, inner_row, inner_col, right_align_super_diagonal, right_align_sub_diagonal);
      int index_in_diag = col - std::max(d, 0) + offset;
      if (d >= lower_index && d <= upper_index) {
        output_addr[i] = diag_addr[batch * num_diags * max_diag_len + diag_index * max_diag_len + index_in_diag];
      }
    }
  }
}
// Host-side launcher: dispatches MatrixSetDiagKernel over every element of the
// batched output on the given stream. output_addr must already contain a copy
// of the input matrices; only the band [lower_index, upper_index] is rewritten.
template <typename T>
void MatrixSetDiag(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                   const int max_diag_len, const int lower_index, const int upper_index,
                   const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                   const bool is_single_diag, const T *diag_addr, T *output_addr, cudaStream_t cuda_stream) {
  // One logical thread per output element.
  const int element_count = outer_batch * inner_row * inner_col;
  MatrixSetDiagKernel<<<GET_BLOCKS(element_count), GET_THREADS, 0, cuda_stream>>>(
    outer_batch, inner_row, inner_col, num_diags, max_diag_len, lower_index, upper_index, right_align_super_diagonal,
    right_align_sub_diagonal, is_single_diag, diag_addr, output_addr);
}
// Explicit instantiations for the dtypes registered by the GPU kernel
// (int32, int64, float32, float64).
template void MatrixSetDiag<int>(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                                 const int max_diag_len, const int lower_index, const int upper_index,
                                 const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                                 const bool is_single_diag, const int *diag_addr, int *output_addr,
                                 cudaStream_t cuda_stream);
template void MatrixSetDiag<int64_t>(const int outer_batch, const int inner_row, const int inner_col,
                                     const int num_diags, const int max_diag_len, const int lower_index,
                                     const int upper_index, const bool right_align_super_diagonal,
                                     const bool right_align_sub_diagonal, const bool is_single_diag,
                                     const int64_t *diag_addr, int64_t *output_addr, cudaStream_t cuda_stream);
template void MatrixSetDiag<float>(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                                   const int max_diag_len, const int lower_index, const int upper_index,
                                   const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                                   const bool is_single_diag, const float *diag_addr, float *output_addr,
                                   cudaStream_t cuda_stream);
template void MatrixSetDiag<double>(const int outer_batch, const int inner_row, const int inner_col,
                                    const int num_diags, const int max_diag_len, const int lower_index,
                                    const int upper_index, const bool right_align_super_diagonal,
                                    const bool right_align_sub_diagonal, const bool is_single_diag,
                                    const double *diag_addr, double *output_addr, cudaStream_t cuda_stream);

View File

@ -0,0 +1,27 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MATRIX_SET_DIAG_IMPL_CUH_
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MATRIX_SET_DIAG_IMPL_CUH_
#include "runtime/device/gpu/cuda_common.h"
// Rewrites the diagonal band [lower_index, upper_index] of outer_batch
// (inner_row x inner_col) matrices in output_addr with values from diag_addr,
// launched asynchronously on cuda_stream. The alignment flags select whether
// super-/sub-diagonals shorter than max_diag_len are right- or left-aligned
// inside the packed diag tensor.
template <typename T>
void MatrixSetDiag(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                   const int max_diag_len, const int lower_index, const int upper_index,
                   const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                   const bool is_single_diag, const T *diag_addr, T *output_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MATRIX_SET_DIAG_IMPL_CUH_

View File

@ -48,6 +48,9 @@ static constexpr char kClean[] = "clean";
// Used by cholesky
static constexpr char kSplitDim[] = "split_dim";
// Used by MatrixSetDiag
static constexpr char kAlignment[] = "alignment";
// Used by MaxPool pad: The minimum value of float32
static constexpr float kSignedMinFloat = -3.402823466e+38F;

View File

@ -16,6 +16,7 @@
from .. import numpy as mnp
from .ops import MatrixSetDiag
from ..common import dtype as mstype
from .utils import _to_tensor
from .utils_const import _raise_value_error
@ -68,5 +69,6 @@ def matrix_set_diag(input_x, diagonal, k=0, alignment="RIGHT_LEFT"):
k_vec = k
else:
_raise_value_error("input k to indicate diagonal region is invalid.")
k_vec = _to_tensor(k_vec, dtype=mstype.int32)
output = matrix_set_diag_net(input_x, diagonal, k_vec)
return output

View File

@ -293,13 +293,14 @@ def fat_cases(align=None):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.int])
def test_matrix_set_diag(data_type):
"""
Feature: ALL TO ALL
Description: test general matrix cases for matrix_set_diag in pynative mode
Description: test general matrix cases for matrix_set_diag in pynative or graph mode
Expectation: the result match expected_diag_matrix.
"""
onp.random.seed(0)
@ -312,20 +313,8 @@ def test_matrix_set_diag(data_type):
expected_diag_matrix = input_mat * mask + banded_mat[0]
output = ops_wrapper.matrix_set_diag(
Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
match_matrix(output, Tensor(expected_diag_matrix))
match_matrix(output.astype(onp.float64), Tensor(expected_diag_matrix))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.int])
def test_graph_matrix_set_diag(data_type):
"""
Feature: ALL TO ALL
Description: test general matrix cases for matrix_set_diag in graph mode
Expectation: the result match expected_diag_matrix.
"""
onp.random.seed(0)
context.set_context(mode=context.GRAPH_MODE)
for align in ALIGNMENT_LIST:
for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]:
@ -335,7 +324,7 @@ def test_graph_matrix_set_diag(data_type):
expected_diag_matrix = input_mat * mask + banded_mat[0]
output = ops_wrapper.matrix_set_diag(
Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
match_matrix(output, Tensor(expected_diag_matrix))
match_matrix(output.astype(onp.float64), Tensor(expected_diag_matrix))
@pytest.mark.level0