diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h
index 66bcbf9149c..6e113325975 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H
-#define MINDSPORE_MATMUL_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
 
 #include
 #include
@@ -30,19 +30,7 @@ namespace kernel {
 template <typename T>
 class MatMulGpuKernel : public GpuKernel {
  public:
-  MatMulGpuKernel()
-      : batch_(0),
-        m_(0),
-        n_(0),
-        k_(0),
-        is_null_input_(false),
-        transpose_x1_(CUBLAS_OP_N),
-        transpose_x2_(CUBLAS_OP_N),
-        handle_(nullptr),
-        dtype_a_(CUDA_R_32F),
-        dtype_b_(CUDA_R_32F),
-        dtype_c_(CUDA_R_32F),
-        algo_(CUBLAS_GEMM_DEFAULT) {}
+  MatMulGpuKernel() { ResetResource(); }
   ~MatMulGpuKernel() = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -122,6 +110,24 @@ class MatMulGpuKernel : public GpuKernel {
     return true;
   }
 
+  void ResetResource() noexcept override {
+    batch_ = 0;
+    m_ = 0;
+    n_ = 0;
+    k_ = 0;
+    is_null_input_ = false;
+    transpose_x1_ = CUBLAS_OP_N;
+    transpose_x2_ = CUBLAS_OP_N;
+    handle_ = nullptr;
+    dtype_a_ = CUDA_R_32F;
+    dtype_b_ = CUDA_R_32F;
+    dtype_c_ = CUDA_R_32F;
+    algo_ = CUBLAS_GEMM_DEFAULT;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
  protected:
   void InitSizeLists() override {
     size_t unit_size = sizeof(T);
@@ -158,4 +164,4 @@ class MatMulGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index c183b78a380..7c7ade82c27 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -289,7 +289,10 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
-
+AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.
diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc
index 1ee95a2ea3c..012e9ff9327 100644
--- a/mindspore/core/abstract/prim_maths.cc
+++ b/mindspore/core/abstract/prim_maths.cc
@@ -317,6 +317,7 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr
     std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
   return ret;
 }
+
 AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
@@ -326,5 +327,93 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
   auto input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
   return input->Broaden();
 }
+
+AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(y);
+  MS_EXCEPTION_IF_NULL(y->shape());
+  if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) {
+    MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
+  }
+  ValuePtr TAptr = primitive->GetAttr("transpose_a");
+  ValuePtr TBptr = primitive->GetAttr("transpose_b");
+  bool TA = GetValue<bool>(TAptr);
+  bool TB = GetValue<bool>(TBptr);
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector y_min_shape = y->shape()->min_shape();
+  ShapeVector y_max_shape = y->shape()->max_shape();
+  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  ShapeVector ret_shape;
+  ShapeVector ret_min_shape;
+  ShapeVector ret_max_shape;
+  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
+    output.push_back(xshp[(TA ? 1 : 0)]);
+    output.push_back(yshp[(TB ? 0 : 1)]);
+    return;
+  };
+  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_min_shape, x_min_shape, y_min_shape);
+  make_shape(ret_max_shape, x_max_shape, y_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(),
+                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+}
+
+AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(y);
+  MS_EXCEPTION_IF_NULL(y->shape());
+  if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) {
+    MS_LOG(EXCEPTION)
+      << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3.";
+  }
+  ValuePtr TAptr = primitive->GetAttr("transpose_a");
+  ValuePtr TBptr = primitive->GetAttr("transpose_b");
+  bool TA = GetValue<bool>(TAptr);
+  bool TB = GetValue<bool>(TBptr);
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector y_min_shape = y->shape()->min_shape();
+  ShapeVector y_max_shape = y->shape()->max_shape();
+  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  ShapeVector ret_shape;
+  ShapeVector ret_min_shape;
+  ShapeVector ret_max_shape;
+  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
+    for (size_t i = 0; i < xshp.size() - 2; i++) {
+      if (xshp[i] != yshp[i]) {
+        if (xshp[i] > 0 && yshp[i] > 0) {
+          MS_LOG(EXCEPTION) << "BatchMatMul input x, y are different at index " << i << ".";
+        }
+        output.push_back(Shape::SHP_ANY);
+      } else {
+        output.push_back(xshp[i]);
+      }
+    }
+    size_t offset = xshp.size() - 2;
+    output.push_back(xshp[offset + (TA ? 1 : 0)]);
+    output.push_back(yshp[offset + (TB ? 0 : 1)]);
+    return;
+  };
+  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_min_shape, x_min_shape, y_min_shape);
+  make_shape(ret_max_shape, x_max_shape, y_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(),
+                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index b739f18eca8..7defa35a14f 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -49,6 +49,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
     {prim::kPrimLinSpace, {InferImplLinSpace, true}},
     {prim::kPrimAddN, {InferImplAddN, true}},
+    {prim::kPrimMatMul, {InferImplMatMul, true}},
+    {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
     // Array
     {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 444817074c1..62962c716d0 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -689,7 +689,7 @@ class CumProd(PrimitiveWithInfer):
             raise ValueError(f"For {self.name}, axis must be const.")
 
 
-class MatMul(PrimitiveWithInfer):
+class MatMul(PrimitiveWithCheck):
     """
     Multiplies matrix `a` and matrix `b`.
 
@@ -730,10 +730,10 @@ class MatMul(PrimitiveWithInfer):
 
     def check_shape_size(self, x1, x2):
         if len(x1) != 2 or len(x2) != 2:
-            raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
+            raise ValueError('P.MatMul inputs x1, x2 should have the same dimension size and '
                              + f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
 
-    def infer_shape(self, x1, x2):
+    def check_shape(self, x1, x2):
         self.check_shape_size(x1, x2)
         cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@@ -747,23 +747,18 @@ class MatMul(PrimitiveWithInfer):
         x2_last = x2[-2:]
         x1_col = x1_last[not self.transpose_a]
         x2_row = x2_last[self.transpose_b]
-        if x1_col != x2_row:
-            raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
-                             + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
-                             + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
+        if np.all(np.array(x1) != -1) and np.all(np.array(x2) != -1):
+            if x1_col != x2_row:
+                raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
+                                 + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+                                 + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
         # set attribute
         self.add_prim_attr('transpose_x1', self.transpose_a)
         self.add_prim_attr('transpose_x2', self.transpose_b)
 
-        ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
-        return ret_dims
-
-    def infer_dtype(self, x1, x2):
+    def check_dtype(self, x1, x2):
         args = {"x1": x1, "x2": x2}
         validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
-        if x1.element_type() == mstype.int8:
-            return mstype.tensor_type(mstype.int32)
-        return x1
 
 
 class BatchMatMul(MatMul):
diff --git a/tests/st/ops/gpu/test_batch_matmul.py b/tests/st/ops/gpu/test_batch_matmul.py
index 7dbab738011..0dc67ea804b 100644
--- a/tests/st/ops/gpu/test_batch_matmul.py
+++ b/tests/st/ops/gpu/test_batch_matmul.py
@@ -21,11 +21,9 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
 class BatchMatMulNet(nn.Cell):
     def __init__(self, transpose_a=False, transpose_b=False):
         super(BatchMatMulNet, self).__init__()
@@ -34,7 +32,9 @@ class BatchMatMulNet(nn.Cell):
     def construct(self, x, y):
         return self.batch_matmul(x, y)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_4d():
     input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
     input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
@@ -140,3 +140,38 @@ def test_4D_fp16():
                         [[4340, 4396, 4456, 4510]],
                         [[5816, 5880, 5948, 6016]]]]).astype(np.float16)
     assert (output.asnumpy() == expect).all()
+
+
+class BatchMatMul_d(nn.Cell):
+    def __init__(self, transpose_a=False, transpose_b=False):
+        super(BatchMatMul_d, self).__init__()
+        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, y):
+        x = self.test_dynamic(x)
+        y = self.test_dynamic(y)
+        return self.batch_matmul(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_batchmatmul_dynamic():
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMul_d()
+
+    x1 = np.arange(8).reshape(2, 2, 2).astype(np.float32)
+    y1 = np.arange(28).reshape(2, 2, 7).astype(np.float32)
+
+    output1 = net(Tensor(x1), Tensor(y1))
+    expect1 = np.matmul(x1, y1)
+    assert (output1.asnumpy() == expect1).all()
+
+    x2 = np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3).astype(np.float32)
+    y2 = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)
+
+    output2 = net(Tensor(x2), Tensor(y2))
+    expect2 = np.matmul(x2, y2)
+    assert (output2.asnumpy() == expect2).all()
diff --git a/tests/st/ops/gpu/test_matmul_op.py b/tests/st/ops/gpu/test_matmul_op.py
new file mode 100644
index 00000000000..59285ded43b
--- /dev/null
+++ b/tests/st/ops/gpu/test_matmul_op.py
@@ -0,0 +1,54 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
+
+class MatMul_d(nn.Cell):
+    def __init__(self):
+        super(MatMul_d, self).__init__()
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.matmul = P.MatMul()
+
+    def construct(self, x, y):
+        x = self.test_dynamic(x)
+        y = self.test_dynamic(y)
+        return self.matmul(x, y)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_MatMul_dynamic():
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = MatMul_d()
+
+    x1 = np.arange(2).reshape(1, 2).astype(np.float32)
+    y1 = np.arange(4).reshape(2, 2).astype(np.float32)
+    output1 = net(Tensor(x1), Tensor(y1))
+    expect1 = np.matmul(x1, y1)
+    np.testing.assert_array_almost_equal(output1.asnumpy(), expect1)
+
+    x2 = np.arange(102).reshape(34, 3).astype(np.float32)
+    y2 = np.arange(18).reshape(3, 6).astype(np.float32)
+    output2 = net(Tensor(x2), Tensor(y2))
+    expect2 = np.matmul(x2, y2)
+    np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)
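
Reviewer note (not part of the patch): a minimal, self-contained Python sketch of the output-shape rule that the new InferImplBatchMatMul encodes. The function name batch_matmul_infer_shape is illustrative only, and -1 stands in for Shape::SHP_ANY, i.e. a dimension that is unknown until runtime. InferImplMatMul applies the same last-two-dimension rule to its two 2-D inputs, without the batch loop.

def batch_matmul_infer_shape(xshp, yshp, transpose_a=False, transpose_b=False):
    # Batch dimensions must agree unless one of them is unknown (-1);
    # an unknown batch dimension propagates as unknown in the output.
    assert len(xshp) == len(yshp) and len(xshp) >= 3
    out = []
    for i, (xd, yd) in enumerate(zip(xshp[:-2], yshp[:-2])):
        if xd != yd:
            if xd > 0 and yd > 0:
                raise ValueError(f"inputs differ at batch index {i}: {xd} vs {yd}")
            out.append(-1)
        else:
            out.append(xd)
    # The last two dimensions follow the (possibly transposed) matmul rule:
    # rows come from x, columns come from y.
    out.append(xshp[-1] if transpose_a else xshp[-2])
    out.append(yshp[-2] if transpose_b else yshp[-1])
    return out

# Example: a dynamic batch dimension stays unknown in the inferred output shape.
print(batch_matmul_infer_shape([-1, 4, 1, 3], [2, 4, 3, 4]))  # -> [-1, 4, 1, 4]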