support spmm

This commit is contained in:
muxiyin 2022-06-09 10:04:30 +08:00
parent 415241381b
commit 27e96dd46d
21 changed files with 239 additions and 8 deletions

View File

@ -83,6 +83,19 @@ mindspore.CSRTensor
Tensor。
.. py:method:: mm(dense_matrix)
返回CSRTensor右乘稠密矩阵的矩阵乘法运算结果。
形状为 `[M, N]` 的CSRTensor需要适配形状为 `[N, K]` 的稠密矩阵,得到结果为 `[M, K]` 的稠密矩阵。
**参数:**
- **dense_matrix** (Tensor) - 形状为 `[NK]` 的二维矩阵其中N等于CSRTensor的列数。
**返回:**
Tensor。
.. py:method:: ndim
:property:

View File

@ -60,6 +60,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
Register(prim::kPrimCSRReduceSum->name(), {3, 4});
Register(prim::kPrimCSRMV->name(), {3});
Register(prim::kPrimCSRMM->name(), {3});
Register(prim::kPrimCSRMul->name(), {3});
Register(prim::kPrimCSRDiv->name(), {3});
Register(prim::kPrimCSRGather->name(), {3});

View File

@ -277,6 +277,7 @@ BuiltInTypeMap &GetMethodMap() {
{"to_tuple", std::string("csr_to_tuple")}, // C.csr_to_tuple
{"to_coo", std::string("csr_to_coo")}, // C.csr_to_coo
{"to_dense", std::string("csr_to_dense")}, // C.csr_to_dense
{"mm", std::string("csr_mm")}, // C.csr_mm
}},
{kObjectTypeCOOTensorType,
{

View File

@ -344,8 +344,8 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
bool IsAKGSparseOP(const AnfNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV, prim::kPrimCSRGather,
prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv};
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV, prim::kPrimCSRGather,
prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
return IsOneOfPrimitiveCNode(cnode, prims);
}
} // namespace mindspore

View File

@ -161,6 +161,8 @@ AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const Primiti
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -1033,5 +1033,48 @@ AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const
AbstractBasePtrList rets = {add1, add0, sub0};
return std::make_shared<AbstractTuple>(rets);
}
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a sparse tensor and a dense tensor.
constexpr auto kCSRMMInputsNum = 5;
constexpr auto kCSRMMShapeSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRMMInputsNum);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(dense);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
auto dense_shape = dense->shape()->shape();
if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! "
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
}
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) {
MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]"
<< ", but dense's shape[0] is: " << dense_shape[kIndexZero]
<< "and csr tensor's shape[1] is" << sparse_shape[kIndexOne];
}
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
} // namespace abstract
} // namespace mindspore

View File

@ -286,6 +286,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimCSRMul, R{InferImplCSRElementWise, nullptr, true}},
{prim::kPrimCSRDiv, R{InferImplCSRElementWise, nullptr, true}},
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
{prim::kPrimCSRMM, R{InferImplCSRMM, nullptr, true}},
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
{prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
{prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},

View File

@ -219,6 +219,7 @@ constexpr auto kCOOTensorDenseMatmul = "COOTensorDenseMatmul";
constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul";
constexpr auto kCSRReduceSum = "CSRReduceSum";
constexpr auto kCSRMV = "CSRMV";
constexpr auto kCSRMM = "CSRMM";
constexpr auto kCSRMul = "CSRMul";
constexpr auto kCSRGather = "CSRGather";
constexpr auto kCSR2COO = "CSR2COO";
@ -754,6 +755,7 @@ GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>
GVAR_DEF(PrimitivePtr, kPrimCOOTensorDenseMatmul, std::make_shared<Primitive>(kCOOTensorDenseMatmul));
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
GVAR_DEF(PrimitivePtr, kPrimCSRMM, std::make_shared<Primitive>(kCSRMM));
GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
GVAR_DEF(PrimitivePtr, kPrimCSRGather, std::make_shared<Primitive>(kCSRGather));
GVAR_DEF(PrimitivePtr, kPrimCSR2COO, std::make_shared<Primitive>(kCSR2COO));

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Providing akg directory base path"""
import importlib
import importlib.util
import os

View File

@ -28,6 +28,7 @@ from ...ops.composite.base import _append, _insert
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations._inner_ops import Format
from ...ops.operations import _csr_ops
from ...ops.primitive import constexpr
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
@ -42,6 +43,7 @@ _format = Format()
_reduce_sum_default = P.ReduceSum()
_reduce_sum_keepdims = P.ReduceSum(True)
_mean_keepdims = P.ReduceMean(True)
_csr_mm = _csr_ops.CSRMM()
itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2,
@ -2295,6 +2297,11 @@ def csr_mv(x, dense_vector):
return F.csr_mv(x, dense_vector)
def csr_mm(x, dense):
"""Implementation of `mm` for CSRTensor."""
return _csr_mm(x.indptr, x.indices, x.values, x.shape, dense)
def csr_to_tuple(x):
"""Implementation of `to_tuple` for CSRTensor."""
res = (x.indptr, x.indices, x.values, x.shape)

View File

@ -4834,6 +4834,35 @@ class CSRTensor(CSRTensor_):
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
def mm(self, dense):
"""
Sparse matrix-matrix multiplication.
Args:
dense_vector (Tensor): A dense Tensor, its shape[0] should be equal to csr_tensor.shape[1]
Returns:
Tensor.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore import dtype as mstype
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> indices = Tensor([0, 1], dtype=mstype.int32)
>>> values = Tensor([2, 1], dtype=mstype.float32)
>>> dense_shape = (2, 4)
>>> csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
>>> Tensor([[1., 2.], [1, 2.], [1, 2.], [1., 2.]], dtype=mstype.float32)
>>> print(csr_tensor.mm(dense))
[[2., 4.]
[1., 2.]]
"""
validator.check_value_type('dense_matrix', dense, (Tensor_,), 'CSRTensor.mm')
return tensor_operator_registry.get("csr_mm")()(self.indptr, self.indices, self.values, self.shape, dense)
def sum(self, axis):
"""
Reduces a dimension of a CSRTensor by summing all elements in the dimension.

View File

@ -18,5 +18,6 @@ from .csr2coo import _csr2coo_akg
from .csr_gather import _csr_gather_akg
from .csr_mul import _csr_mul_akg
from .csr_mv import _csr_mv_akg
from .csr_mm import _csr_mm_akg
from .csr_reduce_sum import _csr_reduce_sum_akg
# Please insert op register in lexicographical order of the filename.

View File

@ -0,0 +1,37 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CSRMM op"""
from mindspore.ops.op_info_register import op_info_register, AkgCpuRegOp, DataType
csr_mm_op_info = AkgCpuRegOp("CSRMM") \
.fusion_type("OPAQUE") \
.input(0, "indptr") \
.input(1, "indices") \
.input(2, "values") \
.input(4, "dense_tensor") \
.output(0, "output") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \
.get_op_info()
@op_info_register(csr_mm_op_info)
def _csr_mm_akg():
"""CSRMM AutoDiff register"""
return

View File

@ -19,6 +19,7 @@ from .csr_gather import _csr_gather_akg
from .csr_div import _csr_div_akg
from .csr_mul import _csr_mul_akg
from .csr_mv import _csr_mv_akg
from .csr_mm import _csr_mm_akg
from .csr_reduce_sum import _csr_reduce_sum_akg
from .equal import _equal_akg
from .greater_equal import _greater_equal_akg

View File

@ -0,0 +1,37 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CSRMM op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
csr_mm_op_info = AkgGpuRegOp("CSRMM") \
.fusion_type("OPAQUE") \
.input(0, "indptr") \
.input(1, "indices") \
.input(2, "values") \
.input(4, "dense_tensor") \
.output(0, "output") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \
.get_op_info()
@op_info_register(csr_mm_op_info)
def _csr_mm_akg():
"""CSRMM AutoDiff register"""
return

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1015,6 +1015,7 @@ tensor_operator_registry.register('csr2coo', csr2coo)
tensor_operator_registry.register('coo2csr', coo2csr)
tensor_operator_registry.register('csr_div', csr_div)
tensor_operator_registry.register('csr_mv', csr_mv)
tensor_operator_registry.register('csr_mm', _csr_ops.CSRMM)
tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum)
tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""csr_ops"""
from ..primitive import prim_attr_register, PrimitiveWithInfer
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
class CSRReduceSum(PrimitiveWithInfer):
@ -350,3 +350,54 @@ class CSRDiv(PrimitiveWithInfer):
"""Initialize CSRDiv"""
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
outputs=['output'])
class CSRMM(Primitive):
"""
Sparse matrix-vector multiplication.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **indptr** (Tensor) - A Tensor.
- **indices** (Tensor) - A Tensor.
- **values** (Tensor) - A Tensor.
- **shape** (tuple(int)) - A tuple.
- **dense_tensor** (Tensor) - A dense Tensor.
Outputs:
Tensor.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.ops.operations import _csr_ops
>>> from mindspore import dtype as mstype
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.op = _csr_ops.CSRMM()
...
... def construct(self, indptr, indices, values, dense_shape, dense):
... return self.op( indptr, indices, values, dense_shape, dense)
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> indices = Tensor([0, 1], dtype=mstype.int32)
>>> values = Tensor([2, 1], dtype=mstype.float32)
>>> dense_shape = (2, 4)
>>> dense = Tensor([[1], [1], [1], [1]], dtype=mstype.float32)
>>> out = Net()(indptr, indices, values, dense_shape, dense)
>>> print(out)
[[2.]
[1.]]
"""
@prim_attr_register
def __init__(self):
"""Initialize CSRMM"""
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense'],
outputs=['output'])

View File

@ -340,11 +340,13 @@ def test_csr_ops():
dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
dense_vector = Tensor([[1.], [1], [1], [1]], dtype=mstype.float32)
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
dense_matrix = Tensor([[1., 2.], [1, 2.], [1, 2.], [1., 2.]], dtype=mstype.float32)
def test_ops_pynative_dense():
dense1 = F.csr_reduce_sum(csr_tensor, 1)
dense2 = F.csr_mv(csr_tensor, dense_vector)
return dense1, dense2
dense3 = csr_tensor.mm(dense_matrix)
return dense1, dense2, dense3
def test_ops_pynative_sparse():
sparse1 = csr_tensor * dense_tensor
@ -359,10 +361,13 @@ def test_csr_ops():
graph_res_dense = test_ops_graph_dense()
expect1 = np.array([[2.], [1.]], dtype=np.float32)
expect2 = np.array([[2.], [1.]], dtype=np.float32)
expect3 = np.array([[2., 4.], [1., 2.]], dtype=np.float32)
assert np.allclose(pynative_res_dense[0].asnumpy(), expect1)
assert np.allclose(pynative_res_dense[1].asnumpy(), expect2)
assert np.allclose(pynative_res_dense[2].asnumpy(), expect3)
assert np.allclose(graph_res_dense[0].asnumpy(), expect1)
assert np.allclose(graph_res_dense[1].asnumpy(), expect2)
assert np.allclose(graph_res_dense[2].asnumpy(), expect3)
pynative_res_sparse = test_ops_pynative_sparse()
graph_res_sparse = test_ops_graph_sparse()

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -23,7 +23,6 @@ square = P.Square()
sqrt = P.Sqrt()
real_div = P.RealDiv()
sub = P.Sub()
Assign = P.Assign()
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')