register Tensor&functional&primitive info for operator SparseMatrixAdd
parent 460ecac5ff
commit 88649053f3
@@ -285,6 +285,7 @@ BuiltInTypeMap &GetMethodMap() {
       {"to_coo", std::string("csr_to_coo")},      // C.csr_to_coo
       {"to_dense", std::string("csr_to_dense")},  // C.csr_to_dense
       {"mm", std::string("csr_mm")},              // C.csr_mm
       {"add", std::string("csr_add")},            // C.csr_add
     }},
    {kObjectTypeCOOTensorType,
     {
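For context, registering `add` in the CSRTensor method map is what lets graph-mode parsing resolve `x.add(...)` on a CSR tensor to the `csr_add` hybrid function added in the next hunk. The snippet below is a minimal sketch, not part of this commit; the Cell name is hypothetical, and it assumes the installed MindSpore build accepts CSRTensor inputs as Cell arguments.

```python
# Sketch only: with "add" registered in the CSRTensor method map above,
# graph mode resolves a.add(...) to the csr_add hybrid function.
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, CSRTensor, context


class CsrAddNet(nn.Cell):  # hypothetical example Cell
    def construct(self, a, b, alpha, beta):
        # Resolved through the method map: "add" -> csr_add
        return a.add(b, alpha, beta)


context.set_context(mode=context.GRAPH_MODE)
indptr = Tensor([0, 1, 2], dtype=mstype.int32)
indices = Tensor([0, 1], dtype=mstype.int32)
a = CSRTensor(indptr, indices, Tensor([2, 1], dtype=mstype.float32), (2, 4))
b = CSRTensor(indptr, indices, Tensor([1, 2], dtype=mstype.float32), (2, 4))
alpha = Tensor(1, mstype.float32)
beta = Tensor(1, mstype.float32)
print(CsrAddNet()(a, b, alpha, beta).values)  # expected: [3. 3.]
```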
@@ -2319,6 +2319,11 @@ def filter_(fun, iter_):
##################


def csr_add(a, b, alpha, beta):
    """Implementation of `csr_add` for CSRTensor."""
    return F.csr_add(a, b, alpha, beta)


def csr_astype(x, dtype):
    """Implementation of `astype` for CSRTensor."""
    data = x.values.astype(dtype)
@@ -5313,6 +5313,42 @@ class CSRTensor(CSRTensor_):
        return CSRTensor(self.indptr, self.indices, data, self.shape)


    def add(self, b, alpha, beta):
        """
        Addition of two CSR Tensors: C = alpha * A + beta * B.

        Args:
            b (CSRTensor): Sparse CSR Tensor.
            alpha (Tensor): Dense Tensor, its shape must be able to broadcast to self.
            beta (Tensor): Dense Tensor, its shape must be able to broadcast to b.

        Returns:
            CSRTensor.

        Supported Platforms:
            ``GPU`` ``CPU``

        Examples:
            >>> from mindspore import Tensor, CSRTensor
            >>> import mindspore.common.dtype as mstype
            >>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
            >>> indices = Tensor([0, 1], dtype=mstype.int32)
            >>> values_a = Tensor([2, 1], dtype=mstype.float32)
            >>> values_b = Tensor([1, 2], dtype=mstype.float32)
            >>> dense_shape = (2, 4)
            >>> alpha = Tensor(1, mstype.float32)
            >>> beta = Tensor(1, mstype.float32)
            >>> a = CSRTensor(indptr, indices, values_a, dense_shape)
            >>> b = CSRTensor(indptr, indices, values_b, dense_shape)
            >>> print(a.add(b, alpha, beta))
            CSRTensor(shape=[2,4], dtype=Float32,
                      indptr=Tensor(shape=[3], dtype=Int32, value = [0, 1, 2]),
                      indices=Tensor(shape=[2], dtype=Int32, value = [0, 1]),
                      values=Tensor(shape=[2], dtype=Float32, value = [3.0, 3.0]))
        """
        return tensor_operator_registry.get('csr_add')(self, b, alpha, beta)


def _vm_compare(*args):
    """Implement `vm_compare` for tensor."""
    obj_str = args[-1]
@@ -314,6 +314,7 @@ from .sparse_func import (
    row_tensor_add,
    sparse_add,
    sparse_concat,
    csr_add,
)
from .random_func import (
    standard_laplace,
@@ -20,13 +20,14 @@ from ..operations.sparse_ops import (
    DenseToCSRSparseMatrix,
    CSRSparseMatrixToSparseTensor,
    SparseConcat,
    SparseAdd
    SparseAdd,
    SparseMatrixAdd
)
from ..operations.array_ops import GatherNd, Coalesce
from ..operations import _csr_ops
from ...common import CSRTensor, COOTensor, Tensor
from ...common import dtype as mstype
from ..composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error
from ..composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor

# utility functions and values
gather_nd = GatherNd()
@@ -516,6 +517,64 @@ def sparse_add(x1, x2, thresh):
    return COOTensor(indices, values, x1.shape)


def csr_add(a, b, alpha, beta):
    """
    Returns alpha * csr_a + beta * csr_b, where csr_a and csr_b are CSRTensor and alpha and beta are both Tensor.

    Args:
        a (CSRTensor): Sparse CSR Tensor.
        b (CSRTensor): Sparse CSR Tensor.
        alpha (Tensor): Dense Tensor, its shape must be able to broadcast to a.
        beta (Tensor): Dense Tensor, its shape must be able to broadcast to b.

    Returns:
        CSRTensor, a CSR tensor containing:
        indptr: indicates the start and end point for `values` in each row.
        indices: the column positions of all non-zero values of the input.
        values: the non-zero values of the dense tensor.
        shape: the shape of the csr_tensor.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.common.dtype as mstype
        >>> from mindspore import Tensor, CSRTensor
        >>> from mindspore.ops.functional import csr_add
        >>> a_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> a_indices = Tensor([0, 1], dtype=mstype.int32)
        >>> a_values = Tensor([1, 2], dtype=mstype.float32)
        >>> shape = (2, 6)
        >>> b_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> b_indices = Tensor([0, 1], dtype=mstype.int32)
        >>> b_values = Tensor([1, 2], dtype=mstype.float32)
        >>> alpha = Tensor(1, mstype.float32)
        >>> beta = Tensor(1, mstype.float32)
        >>> csra = CSRTensor(a_indptr, a_indices, a_values, shape)
        >>> csrb = CSRTensor(b_indptr, b_indices, b_values, shape)
        >>> out = csr_add(csra, csrb, alpha, beta)
        >>> print(out)
        CSRTensor(shape=[2,6], dtype=Float32,
                  indptr=Tensor(shape=[3], dtype=Int32, value = [0, 1, 2]),
                  indices=Tensor(shape=[2], dtype=Int32, value = [0, 1]),
                  values=Tensor(shape=[2], dtype=Float32, value = [2.0, 4.0]))
    """
    if not isinstance(a, CSRTensor) or not isinstance(b, CSRTensor):
        raise_type_error("For functional operator csr_add, both inputs a and b must be of type CSRTensor.")
    if not isinstance(alpha, Tensor) or not isinstance(beta, Tensor):
        raise_type_error("For functional operator csr_add, both inputs alpha and beta must be Tensor.")
    csr_add_op = SparseMatrixAdd()
    a_batch_pointers = make_tensor([0, a.values.shape[0]], dtype=mstype.int32)
    b_batch_pointers = make_tensor([0, b.values.shape[0]], dtype=mstype.int32)
    a_shape = make_tensor(a.shape, dtype=mstype.int32)
    b_shape = make_tensor(b.shape, dtype=mstype.int32)
    shape, _, indptr, indices, values = csr_add_op(a_shape, a_batch_pointers, a.indptr, a.indices, a.values,
                                                   b_shape, b_batch_pointers, b.indptr, b.indices, b.values,
                                                   alpha, beta)
    output_shape = tuple(shape.asnumpy().tolist())
    return CSRTensor(indptr=indptr, indices=indices, values=values, shape=output_shape)


__all__ = [
    'coalesce',
    'coo2csr',
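The documented example can be cross-checked against scipy.sparse, which computes the same alpha * A + beta * B directly on CSR components. This is an illustrative sketch assuming SciPy is available; it is not code from the commit.

```python
# Cross-check sketch (assumes SciPy is installed); reproduces the csr_add example
# with scipy.sparse and prints the expected indptr/indices/values.
import numpy as np
from scipy.sparse import csr_matrix

indptr = np.array([0, 1, 2], dtype=np.int32)
indices = np.array([0, 1], dtype=np.int32)
a = csr_matrix((np.array([1.0, 2.0], dtype=np.float32), indices, indptr), shape=(2, 6))
b = csr_matrix((np.array([1.0, 2.0], dtype=np.float32), indices, indptr), shape=(2, 6))
alpha, beta = 1.0, 1.0

c = alpha * a + beta * b
print(c.indptr)   # [0 1 2]
print(c.indices)  # [0 1]
print(c.data)     # [2. 4.]
```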
@@ -545,6 +604,7 @@ __all__ = [
    'row_tensor_add',
    'sparse_add',
    'sparse_concat',
    'csr_add'
]

__all__.sort()
@@ -889,6 +889,7 @@ tensor_operator_registry.register('log', log)
tensor_operator_registry.register('lerp', lerp)
tensor_operator_registry.register('floor', floor)
# support sparse tensor operators
tensor_operator_registry.register('csr_add', csr_add)
tensor_operator_registry.register('csr_mul', csr_mul)
tensor_operator_registry.register('csr2coo', csr2coo)
tensor_operator_registry.register('coo2csr', coo2csr)
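With this registry entry in place, the Tensor-method route and the functional route should agree. The sketch below is an illustration only (the `csr_add` import path is the one used by the new test file later in this commit).

```python
# Illustration only: the method route (CSRTensor.add -> tensor_operator_registry.get('csr_add'))
# and the functional route (csr_add -> SparseMatrixAdd) should produce the same result.
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, CSRTensor
from mindspore.ops.function import csr_add  # same import path as the new test below

indptr = Tensor([0, 1, 2], dtype=mstype.int32)
indices = Tensor([0, 1], dtype=mstype.int32)
a = CSRTensor(indptr, indices, Tensor([1, 2], dtype=mstype.float32), (2, 6))
b = CSRTensor(indptr, indices, Tensor([1, 2], dtype=mstype.float32), (2, 6))
alpha = Tensor(1, mstype.float32)
beta = Tensor(1, mstype.float32)

via_method = a.add(b, alpha, beta)         # dispatched through the registry entry above
via_function = csr_add(a, b, alpha, beta)  # functional wrapper calls the primitive directly
assert np.allclose(via_method.values.asnumpy(), via_function.values.asnumpy())
```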
@@ -1194,3 +1194,81 @@ class SparseMatrixMatMul(Primitive):
        validator.check_value_type("conjugate_output", conjugate_output, [bool], self.name)
        self.init_prim_io_names(inputs=['x1_dense_shape', 'x1_batch_pointers', 'x1_row_pointers',
                                        'x1_col_indices', 'x1_values', 'x2_dense'], outputs=['y_dense'])


class SparseMatrixAdd(Primitive):
    """
    Addition of two CSR Tensors: C = alpha * A + beta * B.

    Inputs:
        - **x1_dense_shape** (Tensor) - A 1-D Tensor that represents the dense shape of the input CSR sparse matrix.
        - **x1_batch_pointers** (Tensor) - A 1-D Tensor. Supposing the input CSR sparse matrix is of
          batch size `n`, it should have shape :math:`(n+1,)`, whose `i`-th element stores the
          accumulated count of non-zero values in the first `i - 1` batches.
        - **x1_row_pointers** (Tensor) - A 1-D Tensor. Supposing the input CSR sparse matrix is of
          batch size `n` and row number `m`, it can be divided into `n` parts, each part of length
          `m + 1`. The `i`-th element of each :math:`(m+1,)` vector stores the accumulated count of
          non-zero values in the first `i - 1` rows of the corresponding batch.
        - **x1_col_indices** (Tensor) - A 1-D Tensor. It represents column indices of the non-zero values
          in the input CSR sparse matrix.
        - **x1_values** (Tensor) - A 1-D Tensor. It represents all the non-zero values in the input CSR sparse matrix.
        - **x2_dense_shape** (Tensor) - A Tensor, same meaning as x1_dense_shape.
        - **x2_batch_pointers** (Tensor) - A Tensor, same meaning as x1_batch_pointers.
        - **x2_row_pointers** (Tensor) - A Tensor, same meaning as x1_row_pointers.
        - **x2_col_indices** (Tensor) - A Tensor, same meaning as x1_col_indices.
        - **x2_values** (Tensor) - A Tensor, same meaning as x1_values.
        - **alpha** (Tensor) - A Tensor.
        - **beta** (Tensor) - A Tensor.

    Outputs:
        - **y1_dense_shape** (Tensor) - A Tensor.
        - **y1_batch_pointers** (Tensor) - A Tensor.
        - **y1_row_pointers** (Tensor) - A Tensor.
        - **y1_col_indices** (Tensor) - A Tensor.
        - **y1_values** (Tensor) - A Tensor.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> import mindspore.common.dtype as mstype
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations.sparse_ops import SparseMatrixAdd
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.op = SparseMatrixAdd()
        ...
        ...     def construct(self, a_shape, a_batch_pointer, a_indptr, a_indices, a_values,
        ...                   b_shape, b_batch_pointer, b_indptr, b_indices, b_values, alpha, beta):
        ...         return self.op(a_shape, a_batch_pointer, a_indptr, a_indices, a_values,
        ...                        b_shape, b_batch_pointer, b_indptr, b_indices, b_values, alpha, beta)
        >>> a_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> a_indices = Tensor([0, 1], dtype=mstype.int32)
        >>> a_values = Tensor([1, 2], dtype=mstype.float32)
        >>> a_pointers = Tensor([0, a_values.shape[0]], dtype=mstype.int32)
        >>> shape = Tensor([2, 6], dtype=mstype.int32)
        >>> b_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> b_indices = Tensor([0, 1], dtype=mstype.int32)
        >>> b_values = Tensor([1, 2], dtype=mstype.float32)
        >>> b_pointers = Tensor([0, b_values.shape[0]], dtype=mstype.int32)
        >>> alpha = Tensor(1, mstype.float32)
        >>> beta = Tensor(1, mstype.float32)
        >>> out = Net()(shape, a_pointers, a_indptr, a_indices, a_values,
        ...             shape, b_pointers, b_indptr, b_indices, b_values, alpha, beta)
        >>> print(out)
        (Tensor(shape=[2], dtype=Int32, value = [2, 6]),
         Tensor(shape=[2], dtype=Int32, value = [0, 2]),
         Tensor(shape=[3], dtype=Int32, value = [0, 1, 2]),
         Tensor(shape=[2], dtype=Int32, value = [0, 1]),
         Tensor(shape=[2], dtype=Float32, value = [2.0, 4.0]))
    """
    @prim_attr_register
    def __init__(self):
        """Initialize SparseMatrixAdd."""
        self.init_prim_io_names(inputs=['x1_dense_shape', 'x1_batch_pointers', 'x1_row_pointers', 'x1_col_indices',
                                        'x1_values', 'x2_dense_shape', 'x2_batch_pointers', 'x2_row_pointers',
                                        'x2_col_indices', 'x2_values', 'alpha', 'beta'],
                                outputs=['y_dense_shape', 'y_batch_pointers', 'y_row_pointers', 'y_col_indices',
                                         'y_values'])
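The batch_pointers / row_pointers layout described in the Inputs section above can be hard to visualize. The sketch below illustrates it for a hypothetical batch of two 2x3 CSR matrices; the rank-3 (batch, rows, cols) dense-shape layout shown here is my assumption, not taken from the commit. The single-batch case, `batch_pointers = [0, nnz]`, is the one used by the functional `csr_add` wrapper earlier in this commit.

```python
# Illustration only (layout assumptions noted in comments):
# batch 0: [[1, 0, 0],      batch 1: [[0, 0, 5],
#           [0, 2, 0]]                [6, 0, 0]]
import numpy as np

values = np.array([1.0, 2.0, 5.0, 6.0], dtype=np.float32)  # all non-zeros, batch 0 first
col_indices = np.array([0, 1, 2, 0], dtype=np.int32)       # column of each non-zero
# batch_pointers accumulates non-zero counts across batches, so batch i owns
# values[batch_pointers[i]:batch_pointers[i + 1]].
batch_pointers = np.array([0, 2, 4], dtype=np.int32)
# row_pointers is n parts of length m + 1, one part per batch.
row_pointers = np.array([0, 1, 2,     # batch 0: row 0 has 1 non-zero, row 1 has 1
                         0, 1, 2],    # batch 1: same pattern
                        dtype=np.int32)
dense_shape = np.array([2, 2, 3], dtype=np.int32)  # assumed (batch, rows, cols) layout
```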
@@ -0,0 +1,184 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""smoke tests for SparseMatrixAdd"""

import pytest
import numpy as np

from mindspore import Tensor, CSRTensor, context
from mindspore.ops.function import csr_add
from mindspore.ops.operations.sparse_ops import SparseMatrixAdd
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr


@constexpr
def _make_tensor(a, dtype=mstype.int64):
    """Converts the input to tensor."""
    if not isinstance(a, (list, tuple, int, float, bool)):
        raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")
    if isinstance(a, (list, tuple)):
        a = np.asarray(a)
        if a.dtype is np.dtype('object'):
            raise ValueError('Input array must have the same size across all dimensions.')
    return Tensor(a, dtype)


def create_csr_tensor():
    a_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
    a_indices = Tensor([0, 1], dtype=mstype.int32)
    a_values = Tensor([1, 2], dtype=mstype.float32)
    shape = (2, 6)
    b_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
    b_indices = Tensor([0, 1], dtype=mstype.int32)
    b_values = Tensor([1, 2], dtype=mstype.float32)
    csra = CSRTensor(a_indptr, a_indices, a_values, shape)
    csrb = CSRTensor(b_indptr, b_indices, b_values, shape)
    return csra, csrb


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_function_csr_add():
    """
    Feature: Test function csr_add.
    Description: Test CSRTensor matrix add.
    Expectation: Success.
    """
    context.set_context(mode=context.GRAPH_MODE)
    alpha = Tensor(1, mstype.float32)
    beta = Tensor(1, mstype.float32)
    csra, csrb = create_csr_tensor()
    c = csr_add(csra, csrb, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [2.0, 4.0]
    assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
    assert np.allclose(c.indices.asnumpy(), c_indices_expected)
    assert np.allclose(c.values.asnumpy(), c_values_expected)

    beta = Tensor(-1, mstype.float32)
    c = csr_add(csra, csrb, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [0.0, 0.0]
    assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
    assert np.allclose(c.indices.asnumpy(), c_indices_expected)
    assert np.allclose(c.values.asnumpy(), c_values_expected)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_graph_csr_add():
    """
    Feature: Test ops SparseMatrixAdd.
    Description: Test CSRTensor matrix add.
    Expectation: Success.
    """
    context.set_context(mode=context.GRAPH_MODE)
    alpha = Tensor(1, mstype.float32)
    beta = Tensor(1, mstype.float32)
    csra, csrb = create_csr_tensor()
    csra_pointers = _make_tensor([0, csra.values.shape[0]], mstype.int32)
    csrb_pointers = _make_tensor([0, csrb.values.shape[0]], mstype.int32)
    csra_shape = _make_tensor(csra.shape, mstype.int32)
    csrb_shape = _make_tensor(csrb.shape, mstype.int32)
    csr_add_op = SparseMatrixAdd()
    c = csr_add_op(csra_shape, csra_pointers, csra.indptr, csra.indices, csra.values,
                   csrb_shape, csrb_pointers, csrb.indptr, csrb.indices, csrb.values, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [2.0, 4.0]
    assert np.allclose(c[2].asnumpy(), c_indptr_expected)
    assert np.allclose(c[3].asnumpy(), c_indices_expected)
    assert np.allclose(c[4].asnumpy(), c_values_expected)

    beta = Tensor(-1, mstype.float32)
    c = csr_add_op(csra_shape, csra_pointers, csra.indptr, csra.indices, csra.values,
                   csrb_shape, csrb_pointers, csrb.indptr, csrb.indices, csrb.values, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [0.0, 0.0]
    assert np.allclose(c[2].asnumpy(), c_indptr_expected)
    assert np.allclose(c[3].asnumpy(), c_indices_expected)
    assert np.allclose(c[4].asnumpy(), c_values_expected)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_csr_add():
    """
    Feature: Test CSRTensor.add in PyNative mode.
    Description: Test CSRTensor matrix add.
    Expectation: Success.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    alpha = Tensor(1, mstype.float32)
    beta = Tensor(1, mstype.float32)
    csra, csrb = create_csr_tensor()
    c = csra.add(csrb, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [2.0, 4.0]
    assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
    assert np.allclose(c.indices.asnumpy(), c_indices_expected)
    assert np.allclose(c.values.asnumpy(), c_values_expected)

    beta = Tensor(-1, mstype.float32)
    c = csra.add(csrb, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [0.0, 0.0]
    assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
    assert np.allclose(c.indices.asnumpy(), c_indices_expected)
    assert np.allclose(c.values.asnumpy(), c_values_expected)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_csr_add():
    """
    Feature: Test CSRTensor.add in Graph mode.
    Description: Test CSRTensor matrix add.
    Expectation: Success.
    """
    context.set_context(mode=context.GRAPH_MODE)
    alpha = Tensor(1, mstype.float32)
    beta = Tensor(1, mstype.float32)
    csra, csrb = create_csr_tensor()
    c = csra.add(csrb, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [2.0, 4.0]
    assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
    assert np.allclose(c.indices.asnumpy(), c_indices_expected)
    assert np.allclose(c.values.asnumpy(), c_values_expected)

    beta = Tensor(-1, mstype.float32)
    c = csra.add(csrb, alpha, beta)
    c_indptr_expected = [0, 1, 2]
    c_indices_expected = [0, 1]
    c_values_expected = [0.0, 0.0]
    assert np.allclose(c.indptr.asnumpy(), c_indptr_expected)
    assert np.allclose(c.indices.asnumpy(), c_indices_expected)
    assert np.allclose(c.values.asnumpy(), c_values_expected)