forked from mindspore-Ecosystem/mindspore
Add dynamic shape support and test cases to MatMul & BatchMatMul
Commit c994eab732 (parent bbdb5500e5)
@@ -14,8 +14,8 @@
 * limitations under the License.
 */

-#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H
-#define MINDSPORE_MATMUL_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H

#include <cublas_v2.h>
#include <cuda_runtime_api.h>
@@ -30,19 +30,7 @@ namespace kernel {
template <typename T>
class MatMulGpuKernel : public GpuKernel {
 public:
-  MatMulGpuKernel()
-      : batch_(0),
-        m_(0),
-        n_(0),
-        k_(0),
-        is_null_input_(false),
-        transpose_x1_(CUBLAS_OP_N),
-        transpose_x2_(CUBLAS_OP_N),
-        handle_(nullptr),
-        dtype_a_(CUDA_R_32F),
-        dtype_b_(CUDA_R_32F),
-        dtype_c_(CUDA_R_32F),
-        algo_(CUBLAS_GEMM_DEFAULT) {}
+  MatMulGpuKernel() { ResetResource(); }
  ~MatMulGpuKernel() = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -122,6 +110,24 @@ class MatMulGpuKernel : public GpuKernel {
    return true;
  }

+  void ResetResource() noexcept override {
+    batch_ = 0;
+    m_ = 0;
+    n_ = 0;
+    k_ = 0;
+    is_null_input_ = false;
+    transpose_x1_ = CUBLAS_OP_N;
+    transpose_x2_ = CUBLAS_OP_N;
+    handle_ = nullptr;
+    dtype_a_ = CUDA_R_32F;
+    dtype_b_ = CUDA_R_32F;
+    dtype_c_ = CUDA_R_32F;
+    algo_ = CUBLAS_GEMM_DEFAULT;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
 protected:
  void InitSizeLists() override {
    size_t unit_size = sizeof(T);
@@ -158,4 +164,4 @@ class MatMulGpuKernel : public GpuKernel {
}  // namespace kernel
}  // namespace mindspore

-#endif
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
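The kernel-side change above is what makes the operator usable with dynamic shapes: instead of initialising its state once in the constructor, the kernel funnels all initialisation through ResetResource(), which also clears the cached size lists. Below is a toy Python illustration (not MindSpore code; the Init/ResetResource interaction is assumed from the pattern above) of why shape-derived state must be dropped when the same kernel object is re-initialised for a new input shape.

class FakeKernel:
    def __init__(self):
        self.reset_resource()

    def reset_resource(self):
        # mirrors ResetResource(): drop all shape-derived state
        self.m = self.n = self.k = 0
        self.input_size_list = []
        self.output_size_list = []

    def init(self, x_shape, y_shape):
        # stands in for the kernel's Init(), called again whenever the input shape changes
        self.reset_resource()
        self.m, self.k = x_shape
        _, self.n = y_shape
        self.input_size_list = [self.m * self.k * 4, self.k * self.n * 4]
        self.output_size_list = [self.m * self.n * 4]


kernel = FakeKernel()
kernel.init((1, 2), (2, 2))     # first launch
kernel.init((34, 3), (3, 6))    # new shape: state is rebuilt, nothing stale is kept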
@@ -289,7 +289,10 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple or list or dict.
@@ -317,6 +317,7 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr
    std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
  return ret;
}

AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
@@ -326,5 +327,93 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
  auto input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  return input->Broaden();
}
+
+AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(y);
+  MS_EXCEPTION_IF_NULL(y->shape());
+  if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) {
+    MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
+  }
+  ValuePtr TAptr = primitive->GetAttr("transpose_a");
+  ValuePtr TBptr = primitive->GetAttr("transpose_b");
+  bool TA = GetValue<bool>(TAptr);
+  bool TB = GetValue<bool>(TBptr);
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector y_min_shape = y->shape()->min_shape();
+  ShapeVector y_max_shape = y->shape()->max_shape();
+  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  ShapeVector ret_shape;
+  ShapeVector ret_min_shape;
+  ShapeVector ret_max_shape;
+  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
+    output.push_back(xshp[(TA ? 1 : 0)]);
+    output.push_back(yshp[(TB ? 0 : 1)]);
+    return;
+  };
+  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_min_shape, x_min_shape, y_min_shape);
+  make_shape(ret_max_shape, x_max_shape, y_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(),
+                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+}
+
+AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(y);
+  MS_EXCEPTION_IF_NULL(y->shape());
+  if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) {
+    MS_LOG(EXCEPTION)
+      << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3.";
+  }
+  ValuePtr TAptr = primitive->GetAttr("transpose_a");
+  ValuePtr TBptr = primitive->GetAttr("transpose_b");
+  bool TA = GetValue<bool>(TAptr);
+  bool TB = GetValue<bool>(TBptr);
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector y_min_shape = y->shape()->min_shape();
+  ShapeVector y_max_shape = y->shape()->max_shape();
+  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  ShapeVector ret_shape;
+  ShapeVector ret_min_shape;
+  ShapeVector ret_max_shape;
+  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
+    for (size_t i = 0; i < xshp.size() - 2; i++) {
+      if (xshp[i] != yshp[i]) {
+        if (xshp[i] > 0 && yshp[i] > 0) {
+          MS_LOG(EXCEPTION) << "BatchMatMul input x, y are different at index " << i << ".";
+        }
+        output.push_back(Shape::SHP_ANY);
+      } else {
+        output.push_back(xshp[i]);
+      }
+    }
+    size_t offset = xshp.size() - 2;
+    output.push_back(xshp[offset + (TA ? 1 : 0)]);
+    output.push_back(yshp[offset + (TB ? 0 : 1)]);
+    return;
+  };
+  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_min_shape, x_min_shape, y_min_shape);
+  make_shape(ret_max_shape, x_max_shape, y_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(),
+                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+}
}  // namespace abstract
}  // namespace mindspore
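For readers who find the C++ easier to follow in Python: below is a small standalone sketch (not part of the diff) of the shape rule the two make_shape lambdas implement, with -1 standing in for Shape::SHP_ANY (an unknown dimension). The last two dimensions follow the transpose-aware matrix-product rule; leading batch dimensions must match unless at least one side is unknown, in which case the output dimension stays unknown.

# Sketch of the shape rules above; -1 plays the role of Shape::SHP_ANY.
def matmul_out_shape(xshp, yshp, ta=False, tb=False):
    # 2-D case: keep the non-contracted dimension of each operand.
    return [xshp[1 if ta else 0], yshp[0 if tb else 1]]

def batch_matmul_out_shape(xshp, yshp, ta=False, tb=False):
    out = []
    for xd, yd in zip(xshp[:-2], yshp[:-2]):
        if xd != yd:
            if xd > 0 and yd > 0:          # both known but different -> error
                raise ValueError("batch dims differ")
            out.append(-1)                 # at least one unknown -> stays unknown
        else:
            out.append(xd)
    out.append(xshp[-2 + (1 if ta else 0)])
    out.append(yshp[-2 + (0 if tb else 1)])
    return out

assert matmul_out_shape([1, 2], [2, 2]) == [1, 2]
assert batch_matmul_out_shape([2, -1, 1, 3], [2, 4, 3, 4]) == [2, -1, 1, 4]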
@@ -49,6 +49,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
    {prim::kPrimLinSpace, {InferImplLinSpace, true}},
    {prim::kPrimAddN, {InferImplAddN, true}},
+   {prim::kPrimMatMul, {InferImplMatMul, true}},
+   {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
    // Array
    {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
    {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
@@ -689,7 +689,7 @@ class CumProd(PrimitiveWithInfer):
            raise ValueError(f"For {self.name}, axis must be const.")


-class MatMul(PrimitiveWithInfer):
+class MatMul(PrimitiveWithCheck):
    """
    Multiplies matrix `a` and matrix `b`.
@@ -730,10 +730,10 @@ class MatMul(PrimitiveWithInfer):

    def check_shape_size(self, x1, x2):
        if len(x1) != 2 or len(x2) != 2:
-            raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
+            raise ValueError('P.MatMul inputs x1, x2 should have the same dimension size and '
                              + f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')

-    def infer_shape(self, x1, x2):
+    def check_shape(self, x1, x2):
        self.check_shape_size(x1, x2)
        cls_name = self.name
        # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@@ -747,23 +747,18 @@ class MatMul(PrimitiveWithInfer):
        x2_last = x2[-2:]
        x1_col = x1_last[not self.transpose_a]
        x2_row = x2_last[self.transpose_b]
-        if x1_col != x2_row:
-            raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
-                             + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
-                             + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
+        if np.all(np.array(x1) != -1) and np.all(np.array(x2) != -1):
+            if x1_col != x2_row:
+                raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
+                                 + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+                                 + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
        # set attribute
        self.add_prim_attr('transpose_x1', self.transpose_a)
        self.add_prim_attr('transpose_x2', self.transpose_b)

-        ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
-        return ret_dims
-
-    def infer_dtype(self, x1, x2):
+    def check_dtype(self, x1, x2):
        args = {"x1": x1, "x2": x2}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
        if x1.element_type() == mstype.int8:
            return mstype.tensor_type(mstype.int32)
        return x1


class BatchMatMul(MatMul):
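With MatMul moved from PrimitiveWithInfer to PrimitiveWithCheck, the Python class no longer computes the output shape (that now comes from InferImplMatMul on the C++ side); check_shape only validates, and the new np.all(... != -1) guard skips the inner-dimension check when any dimension is dynamic. A standalone sketch of that behaviour (hypothetical helper, not part of the diff):

def check_matmul_shapes(x1, x2, transpose_a=False, transpose_b=False):
    # Dynamic dimensions are reported as -1, so the inner-dimension check
    # only runs when every dimension of both inputs is known.
    if len(x1) != 2 or len(x2) != 2:
        raise ValueError("MatMul expects two 2-D inputs")
    if all(d != -1 for d in x1) and all(d != -1 for d in x2):
        x1_col = x1[-2:][not transpose_a]
        x2_row = x2[-2:][transpose_b]
        if x1_col != x2_row:
            raise ValueError(f"cannot multiply shapes {x1} and {x2}")

check_matmul_shapes([34, 3], [3, 6])     # static shapes: validated
check_matmul_shapes([-1, 3], [3, 6])     # dynamic dim: check is skipped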
@@ -21,11 +21,9 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
class BatchMatMulNet(nn.Cell):
    def __init__(self, transpose_a=False, transpose_b=False):
        super(BatchMatMulNet, self).__init__()
@@ -34,7 +32,9 @@ class BatchMatMulNet(nn.Cell):
    def construct(self, x, y):
        return self.batch_matmul(x, y)


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
def test_4d():
    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
@@ -140,3 +140,38 @@ def test_4D_fp16():
                          [[4340, 4396, 4456, 4510]],
                          [[5816, 5880, 5948, 6016]]]]).astype(np.float16)
    assert (output.asnumpy() == expect).all()
+
+
+class BatchMatMul_d(nn.Cell):
+    def __init__(self, transpose_a=False, transpose_b=False):
+        super(BatchMatMul_d, self).__init__()
+        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, y):
+        x = self.test_dynamic(x)
+        y = self.test_dynamic(y)
+        return self.batch_matmul(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_batchmatmul_dynamic():
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMul_d()
+
+    x1 = np.arange(8).reshape(2, 2, 2).astype(np.float32)
+    y1 = np.arange(28).reshape(2, 2, 7).astype(np.float32)
+
+    output1 = net(Tensor(x1), Tensor(y1))
+    expect1 = np.matmul(x1, y1)
+    assert (output1.asnumpy() == expect1).all()
+
+    x2 = np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3).astype(np.float32)
+    y2 = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)
+
+    output2 = net(Tensor(x2), Tensor(y2))
+    expect2 = np.matmul(x2, y2)
+    assert (output2.asnumpy() == expect2).all()
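The new test drives the same cell twice with inputs of different rank, which forces the GPU kernel to re-initialise between launches. A further case one might add (a sketch only, not part of the commit) exercises the transpose flags together with dynamic shapes, reusing the BatchMatMul_d cell defined above:

def test_batchmatmul_dynamic_transpose():
    # Sketch: transpose_a/transpose_b combined with dynamic shapes.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMul_d(transpose_a=True, transpose_b=True)

    x = np.arange(24).reshape(2, 3, 4).astype(np.float32)   # treated as (2, 4, 3)
    y = np.arange(30).reshape(2, 5, 3).astype(np.float32)   # treated as (2, 3, 5)

    output = net(Tensor(x), Tensor(y))
    expect = np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1))
    assert (output.asnumpy() == expect).all()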
@@ -0,0 +1,54 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
+
+
+class MatMul_d(nn.Cell):
+    def __init__(self):
+        super(MatMul_d, self).__init__()
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.matmul = P.MatMul()
+
+    def construct(self, x, y):
+        x = self.test_dynamic(x)
+        y = self.test_dynamic(y)
+        return self.matmul(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_MatMul_dynamic():
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = MatMul_d()
+
+    x1 = np.arange(2).reshape(1, 2).astype(np.float32)
+    y1 = np.arange(4).reshape(2, 2).astype(np.float32)
+    output1 = net(Tensor(x1), Tensor(y1))
+    expect1 = np.matmul(x1, y1)
+    np.testing.assert_array_almost_equal(output1.asnumpy(), expect1)
+
+    x2 = np.arange(102).reshape(34, 3).astype(np.float32)
+    y2 = np.arange(18).reshape(3, 6).astype(np.float32)
+    output2 = net(Tensor(x2), Tensor(y2))
+    expect2 = np.matmul(x2, y2)
+    np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)
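The two shape pairs above already force a re-size between launches. A transposed variant would need a cell that forwards the transpose flags; the sketch below is hypothetical and not part of the commit (MatMulT_d is an illustrative name), showing what such a case could look like:

class MatMulT_d(nn.Cell):
    # Hypothetical variant of MatMul_d with transpose flags, for illustration only.
    def __init__(self, transpose_a=False, transpose_b=False):
        super(MatMulT_d, self).__init__()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.matmul = P.MatMul(transpose_a, transpose_b)

    def construct(self, x, y):
        x = self.test_dynamic(x)
        y = self.test_dynamic(y)
        return self.matmul(x, y)


def test_matmul_dynamic_transpose():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = MatMulT_d(transpose_a=True)

    x = np.arange(6).reshape(3, 2).astype(np.float32)    # (3, 2), used as (2, 3)
    y = np.arange(12).reshape(3, 4).astype(np.float32)
    output = net(Tensor(x), Tensor(y))
    expect = np.matmul(x.T, y)
    np.testing.assert_array_almost_equal(output.asnumpy(), expect)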