support spmm
This commit is contained in:
parent
415241381b
commit
27e96dd46d
|
@ -83,6 +83,19 @@ mindspore.CSRTensor
|
|||
|
||||
Tensor。
|
||||
|
||||
.. py:method:: mm(dense_matrix)
|
||||
|
||||
返回CSRTensor右乘稠密矩阵的矩阵乘法运算结果。
|
||||
形状为 `[M, N]` 的CSRTensor,需要适配形状为 `[N, K]` 的稠密矩阵,得到结果为 `[M, K]` 的稠密矩阵。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **dense_matrix** (Tensor) - 形状为 `[N,K]` 的二维矩阵,其中N等于CSRTensor的列数。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor。
|
||||
|
||||
.. py:method:: ndim
|
||||
:property:
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
|
|||
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
|
||||
Register(prim::kPrimCSRReduceSum->name(), {3, 4});
|
||||
Register(prim::kPrimCSRMV->name(), {3});
|
||||
Register(prim::kPrimCSRMM->name(), {3});
|
||||
Register(prim::kPrimCSRMul->name(), {3});
|
||||
Register(prim::kPrimCSRDiv->name(), {3});
|
||||
Register(prim::kPrimCSRGather->name(), {3});
|
||||
|
|
|
@ -277,6 +277,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"to_tuple", std::string("csr_to_tuple")}, // C.csr_to_tuple
|
||||
{"to_coo", std::string("csr_to_coo")}, // C.csr_to_coo
|
||||
{"to_dense", std::string("csr_to_dense")}, // C.csr_to_dense
|
||||
{"mm", std::string("csr_mm")}, // C.csr_mm
|
||||
}},
|
||||
{kObjectTypeCOOTensorType,
|
||||
{
|
||||
|
|
|
@ -344,8 +344,8 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
|
|||
|
||||
bool IsAKGSparseOP(const AnfNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV, prim::kPrimCSRGather,
|
||||
prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv};
|
||||
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV, prim::kPrimCSRGather,
|
||||
prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
|
||||
return IsOneOfPrimitiveCNode(cnode, prims);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -161,6 +161,8 @@ AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const Primiti
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -1033,5 +1033,48 @@ AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const
|
|||
AbstractBasePtrList rets = {add1, add0, sub0};
|
||||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a sparse tensor and a dense tensor.
|
||||
constexpr auto kCSRMMInputsNum = 5;
|
||||
constexpr auto kCSRMMShapeSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRMMInputsNum);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
|
||||
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! "
|
||||
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
|
||||
}
|
||||
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) {
|
||||
MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]"
|
||||
<< ", but dense's shape[0] is: " << dense_shape[kIndexZero]
|
||||
<< "and csr tensor's shape[1] is" << sparse_shape[kIndexOne];
|
||||
}
|
||||
|
||||
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
|
||||
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -286,6 +286,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimCSRMul, R{InferImplCSRElementWise, nullptr, true}},
|
||||
{prim::kPrimCSRDiv, R{InferImplCSRElementWise, nullptr, true}},
|
||||
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
|
||||
{prim::kPrimCSRMM, R{InferImplCSRMM, nullptr, true}},
|
||||
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
|
||||
{prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
|
||||
{prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},
|
||||
|
|
|
@ -219,6 +219,7 @@ constexpr auto kCOOTensorDenseMatmul = "COOTensorDenseMatmul";
|
|||
constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul";
|
||||
constexpr auto kCSRReduceSum = "CSRReduceSum";
|
||||
constexpr auto kCSRMV = "CSRMV";
|
||||
constexpr auto kCSRMM = "CSRMM";
|
||||
constexpr auto kCSRMul = "CSRMul";
|
||||
constexpr auto kCSRGather = "CSRGather";
|
||||
constexpr auto kCSR2COO = "CSR2COO";
|
||||
|
@ -754,6 +755,7 @@ GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>
|
|||
GVAR_DEF(PrimitivePtr, kPrimCOOTensorDenseMatmul, std::make_shared<Primitive>(kCOOTensorDenseMatmul));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRMM, std::make_shared<Primitive>(kCSRMM));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRGather, std::make_shared<Primitive>(kCSRGather));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSR2COO, std::make_shared<Primitive>(kCSR2COO));
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing akg directory base path"""
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ from ...ops.composite.base import _append, _insert
|
|||
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||
from ...ops.operations._inner_ops import Format
|
||||
from ...ops.operations import _csr_ops
|
||||
from ...ops.primitive import constexpr
|
||||
|
||||
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
|
||||
|
@ -42,6 +43,7 @@ _format = Format()
|
|||
_reduce_sum_default = P.ReduceSum()
|
||||
_reduce_sum_keepdims = P.ReduceSum(True)
|
||||
_mean_keepdims = P.ReduceMean(True)
|
||||
_csr_mm = _csr_ops.CSRMM()
|
||||
|
||||
itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
|
||||
mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2,
|
||||
|
@ -2295,6 +2297,11 @@ def csr_mv(x, dense_vector):
|
|||
return F.csr_mv(x, dense_vector)
|
||||
|
||||
|
||||
def csr_mm(x, dense):
|
||||
"""Implementation of `mm` for CSRTensor."""
|
||||
return _csr_mm(x.indptr, x.indices, x.values, x.shape, dense)
|
||||
|
||||
|
||||
def csr_to_tuple(x):
|
||||
"""Implementation of `to_tuple` for CSRTensor."""
|
||||
res = (x.indptr, x.indices, x.values, x.shape)
|
||||
|
|
|
@ -4834,6 +4834,35 @@ class CSRTensor(CSRTensor_):
|
|||
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
|
||||
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
|
||||
|
||||
def mm(self, dense):
|
||||
"""
|
||||
Sparse matrix-matrix multiplication.
|
||||
|
||||
Args:
|
||||
dense_vector (Tensor): A dense Tensor, its shape[0] should be equal to csr_tensor.shape[1]
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor, CSRTensor
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
|
||||
>>> indices = Tensor([0, 1], dtype=mstype.int32)
|
||||
>>> values = Tensor([2, 1], dtype=mstype.float32)
|
||||
>>> dense_shape = (2, 4)
|
||||
>>> csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
|
||||
>>> Tensor([[1., 2.], [1, 2.], [1, 2.], [1., 2.]], dtype=mstype.float32)
|
||||
>>> print(csr_tensor.mm(dense))
|
||||
[[2., 4.]
|
||||
[1., 2.]]
|
||||
"""
|
||||
validator.check_value_type('dense_matrix', dense, (Tensor_,), 'CSRTensor.mm')
|
||||
return tensor_operator_registry.get("csr_mm")()(self.indptr, self.indices, self.values, self.shape, dense)
|
||||
|
||||
def sum(self, axis):
|
||||
"""
|
||||
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
|
||||
|
|
|
@ -18,5 +18,6 @@ from .csr2coo import _csr2coo_akg
|
|||
from .csr_gather import _csr_gather_akg
|
||||
from .csr_mul import _csr_mul_akg
|
||||
from .csr_mv import _csr_mv_akg
|
||||
from .csr_mm import _csr_mm_akg
|
||||
from .csr_reduce_sum import _csr_reduce_sum_akg
|
||||
# Please insert op register in lexicographical order of the filename.
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CSRMM op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgCpuRegOp, DataType
|
||||
|
||||
csr_mm_op_info = AkgCpuRegOp("CSRMM") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "indptr") \
|
||||
.input(1, "indices") \
|
||||
.input(2, "values") \
|
||||
.input(4, "dense_tensor") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(csr_mm_op_info)
|
||||
def _csr_mm_akg():
|
||||
"""CSRMM AutoDiff register"""
|
||||
return
|
|
@ -19,6 +19,7 @@ from .csr_gather import _csr_gather_akg
|
|||
from .csr_div import _csr_div_akg
|
||||
from .csr_mul import _csr_mul_akg
|
||||
from .csr_mv import _csr_mv_akg
|
||||
from .csr_mm import _csr_mm_akg
|
||||
from .csr_reduce_sum import _csr_reduce_sum_akg
|
||||
from .equal import _equal_akg
|
||||
from .greater_equal import _greater_equal_akg
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CSRMM op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
|
||||
|
||||
csr_mm_op_info = AkgGpuRegOp("CSRMM") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "indptr") \
|
||||
.input(1, "indices") \
|
||||
.input(2, "values") \
|
||||
.input(4, "dense_tensor") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(csr_mm_op_info)
|
||||
def _csr_mm_akg():
|
||||
"""CSRMM AutoDiff register"""
|
||||
return
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1015,6 +1015,7 @@ tensor_operator_registry.register('csr2coo', csr2coo)
|
|||
tensor_operator_registry.register('coo2csr', coo2csr)
|
||||
tensor_operator_registry.register('csr_div', csr_div)
|
||||
tensor_operator_registry.register('csr_mv', csr_mv)
|
||||
tensor_operator_registry.register('csr_mm', _csr_ops.CSRMM)
|
||||
tensor_operator_registry.register('csr_reduce_sum', csr_reduce_sum)
|
||||
tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
|
||||
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""csr_ops"""
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||
|
||||
|
||||
class CSRReduceSum(PrimitiveWithInfer):
|
||||
|
@ -350,3 +350,54 @@ class CSRDiv(PrimitiveWithInfer):
|
|||
"""Initialize CSRDiv"""
|
||||
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
|
||||
outputs=['output'])
|
||||
|
||||
|
||||
class CSRMM(Primitive):
|
||||
"""
|
||||
Sparse matrix-vector multiplication.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Inputs:
|
||||
- **indptr** (Tensor) - A Tensor.
|
||||
- **indices** (Tensor) - A Tensor.
|
||||
- **values** (Tensor) - A Tensor.
|
||||
- **shape** (tuple(int)) - A tuple.
|
||||
- **dense_tensor** (Tensor) - A dense Tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops.operations import _csr_ops
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.op = _csr_ops.CSRMM()
|
||||
...
|
||||
... def construct(self, indptr, indices, values, dense_shape, dense):
|
||||
... return self.op( indptr, indices, values, dense_shape, dense)
|
||||
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
|
||||
>>> indices = Tensor([0, 1], dtype=mstype.int32)
|
||||
>>> values = Tensor([2, 1], dtype=mstype.float32)
|
||||
>>> dense_shape = (2, 4)
|
||||
>>> dense = Tensor([[1], [1], [1], [1]], dtype=mstype.float32)
|
||||
>>> out = Net()(indptr, indices, values, dense_shape, dense)
|
||||
>>> print(out)
|
||||
[[2.]
|
||||
[1.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize CSRMM"""
|
||||
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense'],
|
||||
outputs=['output'])
|
||||
|
|
|
@ -340,11 +340,13 @@ def test_csr_ops():
|
|||
dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
|
||||
dense_vector = Tensor([[1.], [1], [1], [1]], dtype=mstype.float32)
|
||||
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
|
||||
dense_matrix = Tensor([[1., 2.], [1, 2.], [1, 2.], [1., 2.]], dtype=mstype.float32)
|
||||
|
||||
def test_ops_pynative_dense():
|
||||
dense1 = F.csr_reduce_sum(csr_tensor, 1)
|
||||
dense2 = F.csr_mv(csr_tensor, dense_vector)
|
||||
return dense1, dense2
|
||||
dense3 = csr_tensor.mm(dense_matrix)
|
||||
return dense1, dense2, dense3
|
||||
|
||||
def test_ops_pynative_sparse():
|
||||
sparse1 = csr_tensor * dense_tensor
|
||||
|
@ -359,10 +361,13 @@ def test_csr_ops():
|
|||
graph_res_dense = test_ops_graph_dense()
|
||||
expect1 = np.array([[2.], [1.]], dtype=np.float32)
|
||||
expect2 = np.array([[2.], [1.]], dtype=np.float32)
|
||||
expect3 = np.array([[2., 4.], [1., 2.]], dtype=np.float32)
|
||||
assert np.allclose(pynative_res_dense[0].asnumpy(), expect1)
|
||||
assert np.allclose(pynative_res_dense[1].asnumpy(), expect2)
|
||||
assert np.allclose(pynative_res_dense[2].asnumpy(), expect3)
|
||||
assert np.allclose(graph_res_dense[0].asnumpy(), expect1)
|
||||
assert np.allclose(graph_res_dense[1].asnumpy(), expect2)
|
||||
assert np.allclose(graph_res_dense[2].asnumpy(), expect3)
|
||||
|
||||
pynative_res_sparse = test_ops_pynative_sparse()
|
||||
graph_res_sparse = test_ops_graph_sparse()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -23,7 +23,6 @@ square = P.Square()
|
|||
sqrt = P.Sqrt()
|
||||
real_div = P.RealDiv()
|
||||
sub = P.Sub()
|
||||
Assign = P.Assign()
|
||||
make_tuple = Primitive('MakeTuple')
|
||||
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||
adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')
|
||||
|
|
Loading…
Reference in New Issue