!1521 GPU MatMul & BiasAdd kernels support null input

Merge pull request !1521 from chenweifeng/empty_input
This commit is contained in:
mindspore-ci-bot 2020-05-28 20:46:35 +08:00 committed by Gitee
commit 8db919fafc
2 changed files with 25 additions and 1 deletion

View File

@ -35,7 +35,8 @@ class BiasAddGpuKernel : public GpuKernel {
cudnn_data_type_(CUDNN_DATA_FLOAT),
x_desc_(nullptr),
b_desc_(nullptr),
op_desc_(nullptr) {}
op_desc_(nullptr),
is_null_input_(false) {}
~BiasAddGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -45,6 +46,10 @@ class BiasAddGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
T *x_addr = GetDeviceAddress<T>(inputs, 0);
T *b_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
@ -65,6 +70,13 @@ class BiasAddGpuKernel : public GpuKernel {
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto num_dims = x_shape.size();
is_null_input_ = CHECK_NULL_INPUT(x_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "input is null";
InitSizeLists();
return true;
}
if (num_dims < 2) {
MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims;
}
@ -126,6 +138,7 @@ class BiasAddGpuKernel : public GpuKernel {
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t b_desc_;
cudnnOpTensorDescriptor_t op_desc_;
bool is_null_input_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

View File

@ -35,6 +35,7 @@ class MatMulGpuKernel : public GpuKernel {
m_(0),
n_(0),
k_(0),
is_null_input_(false),
transpose_x1_(CUBLAS_OP_N),
transpose_x2_(CUBLAS_OP_N),
handle_(nullptr),
@ -51,6 +52,9 @@ class MatMulGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
auto input2_addr = GetDeviceAddress<T>(inputs, 1);
auto output_addr = GetDeviceAddress<T>(outputs, 0);
@ -82,6 +86,12 @@ class MatMulGpuKernel : public GpuKernel {
dtype_b_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))];
dtype_c_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))];
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(output_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "input is null";
InitSizeLists();
return true;
}
auto dims = output_shape.size();
if (dims < 2) {
MS_LOG(EXCEPTION) << "Output dims " << dims << " not support.";
@ -125,6 +135,7 @@ class MatMulGpuKernel : public GpuKernel {
size_t m_;
size_t n_;
size_t k_;
bool is_null_input_;
cublasOperation_t transpose_x1_;
cublasOperation_t transpose_x2_;