register Tensor&functional&primitive info for operator SparseMatrixAdd

This commit is contained in:
mwang 2022-05-31 11:20:17 +08:00 committed by shibeiji
parent 460ecac5ff
commit 88649053f3
8 changed files with 368 additions and 2 deletions

View File

@ -285,6 +285,7 @@ BuiltInTypeMap &GetMethodMap() {
{"to_coo", std::string("csr_to_coo")}, // C.csr_to_coo
{"to_dense", std::string("csr_to_dense")}, // C.csr_to_dense
{"mm", std::string("csr_mm")}, // C.csr_mm
{"add", std::string("csr_add")}, // C.csr_add
}},
{kObjectTypeCOOTensorType,
{

View File

@ -2319,6 +2319,11 @@ def filter_(fun, iter_):
##################
def csr_add(a, b, alpha, beta):
"""Implementation of "csr_add" for CSRTensor."""
return F.csr_add(a, b, alpha, beta)
def csr_astype(x, dtype):
"""Implementation of `astype` for CSRTensor."""
data = x.values.astype(dtype)

View File

@ -5313,6 +5313,42 @@ class CSRTensor(CSRTensor_):
return CSRTensor(self.indptr, self.indices, data, self.shape)
def add(self, b, alpha, beta):
"""
Addition of two CSR Tensors : C = alpha * A + beta * B
Args:
b (CSRTensor): Sparse CSR Tensor.
alpha(Tensor): Dense Tensor, its shape must be able to broadcast to self.
beta(Tensor): Dense Tensor, its shape must be able to broadcast to b.
Returns:
CSRTensor.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor, CSRTensor
>>> import mindspore.common.dtype as mstype
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> indices = Tensor([0, 1], dtype=mstype.int32)
>>> values_a = Tensor([2, 1], dtype=mstype.float32)
>>> values_b = Tensor([1, 2], dtype=mstype.float32)
>>> dense_shape = (2, 4)
>>> alpha = Tensor(1, mstype.float32)
>>> beta = Tensor(1, mstype.float32)
>>> a = CSRTensor(indptr, indices, values_a, dense_shape)
>>> b = CSRTensor(indptr, indices, values_b, dense_shape)
>>> print(a.add(b, alpha, beta))
CSRTensor(shape=[2,4], dtype=Float32,
indptr=Tensor(shape=[3], dtype=Int32, value = [0, 1, 2]),
indices=Tensor(shape=[2], dtype=Int32, value = [0, 1]),
values=Tensor(shape=[2], dtype=Float32, value = [3.0, 3.0]))
"""
return tensor_operator_registry.get('csr_add')(self, b, alpha, beta)
def _vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]

View File

@ -314,6 +314,7 @@ from .sparse_func import (
row_tensor_add,
sparse_add,
sparse_concat,
csr_add,
)
from .random_func import (
standard_laplace,

View File

@ -20,13 +20,14 @@ from ..operations.sparse_ops import (
DenseToCSRSparseMatrix,
CSRSparseMatrixToSparseTensor,
SparseConcat,
SparseAdd
SparseAdd,
SparseMatrixAdd
)
from ..operations.array_ops import GatherNd, Coalesce
from ..operations import _csr_ops
from ...common import CSRTensor, COOTensor, Tensor
from ...common import dtype as mstype
from ..composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error
from ..composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor
# utility functions and values
gather_nd = GatherNd()
@ -516,6 +517,64 @@ def sparse_add(x1, x2, thresh):
return COOTensor(indices, values, x1.shape)
def csr_add(a, b, alpha, beta):
"""
Returns alpha * csr_a + beta * csr_b where both csr_a and csr_b are CSRTensor, alpha and beta are both Tensor.
Args:
a (CSRTensor): Sparse CSR Tensor.
b (CSRTensor): Sparse CSR Tensor.
alpha(Tensor): Dense Tensor, its shape must be able to broadcast to a.
beta(Tensor): Dense Tensor, its shape must be able to broadcast to b.
Returns:
CSRTensor. a csr_tensor containing:
indptr: indicates the start and end point for `values` in each row.
indices: the column positions of all non-zero values of the input.
values: the non-zero values of the dense tensor.
shape: the shape of the csr_tensor.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore.common.dtype as mstype
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore.ops.functional import csr_add
>>> a_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> a_indices = Tensor([0, 1], dtype=mstype.int32)
>>> a_values = Tensor([1, 2], dtype=mstype.float32)
>>> shape = (2, 6)
>>> b_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> b_indices = Tensor([0, 1], dtype=mstype.int32)
>>> b_values = Tensor([1, 2], dtype=mstype.float32)
>>> alpha = Tensor(1, mstype.float32)
>>> beta = Tensor(1, mstype.float32)
>>> csra = CSRTensor(a_indptr, a_indices, a_values, shape)
>>> csrb = CSRTensor(b_indptr, b_indices, b_values, shape)
>>> out = csr_add(csra, csrb, alpha, beta)
>>> print(out)
CSRTensor(shape=[2,6], dtype=Float32,
indptr=Tensor(shape=[3], dtype=Int32, value = [0, 1, 2]),
indices=Tensor(shape=[2], dtype=Int32, value = [0, 1]),
values=Tensor(shape=[2], dtype=Float32, value = [2.0, 4.0]))
"""
if not isinstance(a, CSRTensor) or not isinstance(b, CSRTensor):
raise_type_error("For functional operator csr_add, both inputs a and b must be type of CSRTensor.")
if not isinstance(alpha, Tensor) or not isinstance(beta, Tensor):
raise_type_error("For functional operator csr_add, both inputs alpha and beta must be Tensor.")
csr_add_op = SparseMatrixAdd()
a_batch_pointers = make_tensor([0, a.values.shape[0]], dtype=mstype.int32)
b_batch_pointers = make_tensor([0, b.values.shape[0]], dtype=mstype.int32)
a_shape = make_tensor(a.shape, dtype=mstype.int32)
b_shape = make_tensor(b.shape, dtype=mstype.int32)
shape, _, indptr, indices, values = csr_add_op(a_shape, a_batch_pointers, a.indptr, a.indices, a.values,
b_shape, b_batch_pointers, b.indptr, b.indices, b.values,
alpha, beta)
output_shape = tuple(shape.asnumpy().tolist())
return CSRTensor(indptr=indptr, indices=indices, values=values, shape=output_shape)
__all__ = [
'coalesce',
'coo2csr',
@ -545,6 +604,7 @@ __all__ = [
'row_tensor_add',
'sparse_add',
'sparse_concat',
'csr_add'
]
__all__.sort()

View File

@ -889,6 +889,7 @@ tensor_operator_registry.register('log', log)
tensor_operator_registry.register('lerp', lerp)
tensor_operator_registry.register('floor', floor)
# support sparse tensor operators
tensor_operator_registry.register('csr_add', csr_add)
tensor_operator_registry.register('csr_mul', csr_mul)
tensor_operator_registry.register('csr2coo', csr2coo)
tensor_operator_registry.register('coo2csr', coo2csr)

View File

@ -1194,3 +1194,81 @@ class SparseMatrixMatMul(Primitive):
validator.check_value_type("conjugate_output", conjugate_output, [bool], self.name)
self.init_prim_io_names(inputs=['x1_dense_shape', 'x1_batch_pointers', 'x1_row_pointers',
'x1_col_indices', 'x1_values', 'x2_dense'], outputs=['y_dense'])
class SparseMatrixAdd(Primitive):
"""
Addition of two CSR Tensors : C = alpha * A + beta * B
Inputs:
- **x1_dense_shape** (Tensor) - A 1-D Tensor represents the dense form shape of the input CSR sparse matrix.
- **x1_batch_pointers** (Tensor) - A 1-D Tensor. Supposing the input CSR sparse matrix is of
batch size `n`, it should have shape :math:`(n+1,)`, while the `i`-th element of which stores
acummulated counts of non-zero values of the first `i - 1` batches.
- **x1_row_pointers** (Tensor) - A 1-D Tensor. Supposing the input CSR sparse matrix is of
batch size `n` and row number `m`, it can be divided into `n` parts, each part of length
`m + 1`. The `i`-th element of each :math:`(m+1,)` vector stores acummulated counts of
non-zero values of the first `i - 1` rows in the corresponding batch.
- **x1_col_indices** (Tensor) - A 1-D Tensor. It represents column indices of the non-zero values
in the input CSR sparse matrix.
- **x1_values** (Tensor) - A 1-D Tensor. It represents all the non-zero values in the input CSR sparse matrix.
- **x2_dense_shape** (Tensor) - A Tensor, same meaning as x1_dense_shape.
- **x2_batch_pointers** (Tensor) - A Tensor, same meaning as x1_batch_pointers.
- **x2_row_pointers** (Tensor) - A Tensor, same meaning as x1_row_pointers.
- **x2_col_indices** (Tensor) - A Tensor, same meaning as x1_col_indices.
- **x2_values** (Tensor) - A Tensor, same meaning as x1_values.
- **alpha** (Tensor) - A Tensor.
- **beta** (Tensor) - A Tensor.
Outputs:
- **y1_dense_shape** (Tensor) - A Tensor.
- **y1_batch_pointers** (Tensor) - A Tensor.
- **y1_row_pointers** (Tensor) - A Tensor.
- **y1_col_indices** (Tensor) - A Tensor.
- **y1_values** (Tensor) - A Tensor.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore.nn as nn
>>> import mindspore.common.dtype as mstype
>>> from mindspore import Tensor
>>> from mindspore.ops.operations.sparse_ops import SparseMatrixAdd
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.op = SparseMatrixAdd()
...
... def construct(self, a_shape, a_batch_pointer, a_indptr, a_indices, a_values,
... b_shape, b_batch_pointer, b_indptr, b_indices, b_values, alpha, beta):
... return self.op(a_shape, a_batch_pointer, a_indptr, a_indices, a_values,
... b_shape, b_batch_pointer, b_indptr, b_indices, b_values, alpha, beta)
>>> a_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> a_indices = Tensor([0, 1], dtype=mstype.int32)
>>> a_values = Tensor([1, 2], dtype=mstype.float32)
>>> a_pointers = Tensor([0, a_values.shape[0]], dtype=mstype.int32)
>>> shape = Tensor([2, 6], dtype=mstype.int32)
>>> b_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> b_indices = Tensor([0, 1], dtype=mstype.int32)
>>> b_values = Tensor([1, 2], dtype=mstype.float32)
>>> b_pointers = Tensor([0, b_values.shape[0]], dtype=mstype.int32)
>>> alpha = Tensor(1, mstype.float32)
>>> beta = Tensor(1, mstype.float32)
>>> out = Net()(shape, a_pointers, a_indptr, a_indices, a_values,
... shape, b_pointers, b_indptr, b_indices, b_values, alpha, beta)
>>> print(out)
(Tensor(shape=[2], dtype=Int32, value =[2, 6]),
Tensor(shape[2], dtype=Int32, value = [0, 2]),
Tensor(shape=[3], dtype=Int32, values = [0, 1, 2]),
Tensor(shape=[2], dtype=Int32, values = [0, 1]),
Tensor(shape=[2], dtype=Float32, values = [2.0, 4.0]))
"""
@prim_attr_register
def __init__(self):
'''Initialize for SparseMatrixAdd'''
self.init_prim_io_names(inputs=['x1_dense_shape', 'x1_batch_pointers', 'x1_row_pointers', 'x1_col_indices',
'x1_values', 'x2_dense_shape', 'x2_batch_pointers', 'x2_row_pointers',
'x2_col_indices', 'x2_values', 'alpha', 'beta'],
outputs=['y_dense_shape', 'y_batch_pointers', 'y_row_pointers', 'y_col_indices',
'y_values'])

View File

@ -0,0 +1,184 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""smoke tests for SparseMatrixAdd"""
import pytest
import numpy as np
from mindspore import Tensor, CSRTensor, context
from mindspore.ops.function import csr_add
from mindspore.ops.operations.sparse_ops import SparseMatrixAdd
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr
@constexpr
def _make_tensor(a, dtype=mstype.int64):
"""Converts the input to tensor."""
if not isinstance(a, (list, tuple, int, float, bool)):
raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")
if isinstance(a, (list, tuple)):
a = np.asarray(a)
if a.dtype is np.dtype('object'):
raise ValueError('Input array must have the same size across all dimensions.')
return Tensor(a, dtype)
def create_csr_tensor():
a_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
a_indices = Tensor([0, 1], dtype=mstype.int32)
a_values = Tensor([1, 2], dtype=mstype.float32)
shape = (2, 6)
b_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
b_indices = Tensor([0, 1], dtype=mstype.int32)
b_values = Tensor([1, 2], dtype=mstype.float32)
csra = CSRTensor(a_indptr, a_indices, a_values, shape)
csrb = CSRTensor(b_indptr, b_indices, b_values, shape)
return csra, csrb
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_function_csr_add():
"""
Feature: Test function csr_add.
Description: Test CSRTensor matrix add.
Expectation: Success.
"""
context.set_context(mode=context.GRAPH_MODE)
alpha = Tensor(1, mstype.float32)
beta = Tensor(1, mstype.float32)
csra, csrb = create_csr_tensor()
c = csr_add(csra, csrb, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [2.0, 4.0]
assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
assert np.allclose(c.indices.asnumpy(), c_indices_expected)
assert np.allclose(c.values.asnumpy(), c_values_excpected)
beta = Tensor(-1, mstype.float32)
c = csr_add(csra, csrb, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [0.0, 0.0]
assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
assert np.allclose(c.indices.asnumpy(), c_indices_expected)
assert np.allclose(c.values.asnumpy(), c_values_excpected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_graph_csr_add():
"""
Feature: Test ops SparseMatrixAdd.
Description: Test CSRTensor matrix add.
Expectation: Success.
"""
context.set_context(mode=context.GRAPH_MODE)
alpha = Tensor(1, mstype.float32)
beta = Tensor(1, mstype.float32)
csra, csrb = create_csr_tensor()
csra_pointers = _make_tensor([0, csra.values.shape[0]], mstype.int32)
csrb_pointers = _make_tensor([0, csrb.values.shape[0]], mstype.int32)
csra_shape = _make_tensor(csra.shape, mstype.int32)
csrb_shape = _make_tensor(csrb.shape, mstype.int32)
csr_add_op = SparseMatrixAdd()
c = csr_add_op(csra_shape, csra_pointers, csra.indptr, csra.indices, csra.values,
csrb_shape, csrb_pointers, csrb.indptr, csrb.indices, csrb.values, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [2.0, 4.0]
assert np.allclose(c[2].asnumpy(), c_indptr_expected)
assert np.allclose(c[3].asnumpy(), c_indices_expected)
assert np.allclose(c[4].asnumpy(), c_values_excpected)
beta = Tensor(-1, mstype.float32)
c = csr_add_op(csra_shape, csra_pointers, csra.indptr, csra.indices, csra.values,
csrb_shape, csrb_pointers, csrb.indptr, csrb.indices, csrb.values, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [0.0, 0.0]
assert np.allclose(c[2].asnumpy(), c_indptr_expected)
assert np.allclose(c[3].asnumpy(), c_indices_expected)
assert np.allclose(c[4].asnumpy(), c_values_excpected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_csr_add():
"""
Feature: Test Tensor csr_add.
Description: Test CSRTensor matrix add.
Expectation: Success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
alpha = Tensor(1, mstype.float32)
beta = Tensor(1, mstype.float32)
csra, csrb = create_csr_tensor()
c = csra.add(csrb, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [2.0, 4.0]
assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
assert np.allclose(c.indices.asnumpy(), c_indices_expected)
assert np.allclose(c.values.asnumpy(), c_values_excpected)
beta = Tensor(-1, mstype.float32)
c = csra.add(csrb, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [0.0, 0.0]
assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
assert np.allclose(c.indices.asnumpy(), c_indices_expected)
assert np.allclose(c.values.asnumpy(), c_values_excpected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_csr_add():
"""
Feature: Test Tensor csr_add.
Description: Test CSRTensor matrix add.
Expectation: Success.
"""
context.set_context(mode=context.GRAPH_MODE)
alpha = Tensor(1, mstype.float32)
beta = Tensor(1, mstype.float32)
csra, csrb = create_csr_tensor()
c = csra.add(csrb, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [2.0, 4.0]
assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
assert np.allclose(c.indices.asnumpy(), c_indices_expected)
assert np.allclose(c.values.asnumpy(), c_values_excpected)
beta = Tensor(-1, mstype.float32)
c = csra.add(csrb, alpha, beta)
c_indptr_expected = [0, 1, 2]
c_indices_expected = [0, 1]
c_values_excpected = [0.0, 0.0]
assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
assert np.allclose(c.indices.asnumpy(), c_indices_expected)
assert np.allclose(c.values.asnumpy(), c_values_excpected)