!3537 add sparse operators
Merge pull request !3537 from riemann_penn/add_sparse_operator
This commit is contained in:
commit
f1e717554c
|
@ -32,9 +32,11 @@ namespace opt {
|
||||||
using mindspore::abstract::AbstractAttribute;
|
using mindspore::abstract::AbstractAttribute;
|
||||||
using mindspore::abstract::AbstractClass;
|
using mindspore::abstract::AbstractClass;
|
||||||
using mindspore::abstract::AbstractDictionary;
|
using mindspore::abstract::AbstractDictionary;
|
||||||
|
using mindspore::abstract::AbstractIndexedSlices;
|
||||||
using mindspore::abstract::AbstractJTagged;
|
using mindspore::abstract::AbstractJTagged;
|
||||||
using mindspore::abstract::AbstractList;
|
using mindspore::abstract::AbstractList;
|
||||||
using mindspore::abstract::AbstractScalar;
|
using mindspore::abstract::AbstractScalar;
|
||||||
|
using mindspore::abstract::AbstractSparseTensor;
|
||||||
using mindspore::abstract::AbstractTuple;
|
using mindspore::abstract::AbstractTuple;
|
||||||
using mindspore::abstract::AbstractUndetermined;
|
using mindspore::abstract::AbstractUndetermined;
|
||||||
|
|
||||||
|
@ -73,6 +75,19 @@ static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
|
||||||
return std::make_shared<AbstractTuple>(abs_list->elements());
|
return std::make_shared<AbstractTuple>(abs_list->elements());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (t->isa<AbstractSparseTensor>()) {
|
||||||
|
auto abs_sparse = dyn_cast<AbstractSparseTensor>(t);
|
||||||
|
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
|
||||||
|
return std::make_shared<AbstractTuple>(abstract_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (t->isa<AbstractIndexedSlices>()) {
|
||||||
|
auto abs_indexed_slices = dyn_cast<AbstractIndexedSlices>(t);
|
||||||
|
std::vector<AbstractBasePtr> abstract_list{abs_indexed_slices->indices(), abs_indexed_slices->values(),
|
||||||
|
abs_indexed_slices->dense_shape()};
|
||||||
|
return std::make_shared<AbstractTuple>(abstract_list);
|
||||||
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -389,14 +404,44 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> inputs;
|
||||||
|
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||||
|
// Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
|
||||||
|
(void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
|
||||||
|
return node->func_graph()->NewCNode(inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int &index) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||||
|
|
||||||
|
const auto &inputs = node->inputs();
|
||||||
|
// Inputs should be [spase_getattr, sparse]
|
||||||
|
if (inputs.size() < 2) {
|
||||||
|
MS_LOG(EXCEPTION) << "Node's input number < 2.";
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr sparse = inputs[1];
|
||||||
|
MS_EXCEPTION_IF_NULL(sparse);
|
||||||
|
auto cons_node = NewValueNode(index);
|
||||||
|
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(index));
|
||||||
|
cons_node->set_abstract(aptr);
|
||||||
|
|
||||||
|
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
manager->AddFuncGraph(root);
|
manager->AddFuncGraph(root);
|
||||||
|
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
|
||||||
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
||||||
AnfNodeSet all_node = manager->all_nodes();
|
auto all_node = manager->all_nodes();
|
||||||
for (auto &node : all_node) {
|
for (auto &node : all_node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
@ -409,6 +454,18 @@ bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||||
new_node = ConvertListSetItemToTupleSetItem(cnode);
|
new_node = ConvertListSetItemToTupleSetItem(cnode);
|
||||||
} else if (IsValueNode<ValueList>(node)) {
|
} else if (IsValueNode<ValueList>(node)) {
|
||||||
new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
|
new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
|
||||||
|
} else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
|
||||||
|
IsPrimitiveCNode(node, prim::kPrimMakeIndexedSlices)) {
|
||||||
|
new_node = ConvertMakeSparseToMakeTuple(cnode);
|
||||||
|
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
|
||||||
|
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetIndices)) {
|
||||||
|
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0);
|
||||||
|
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
|
||||||
|
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetValues)) {
|
||||||
|
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1);
|
||||||
|
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
|
||||||
|
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetDenseShape)) {
|
||||||
|
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (new_node != nullptr) {
|
if (new_node != nullptr) {
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace opt {
|
||||||
|
|
||||||
// Remove the class type from graphs
|
// Remove the class type from graphs
|
||||||
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||||
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||||
|
|
||||||
// Remove most uses of tuples from the graph
|
// Remove most uses of tuples from the graph
|
||||||
// tuples that are returned will be kept
|
// tuples that are returned will be kept
|
||||||
|
|
|
@ -69,11 +69,11 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CleanListPass(const ResourcePtr &res) {
|
bool CleanAfterOptAPass(const ResourcePtr &res) {
|
||||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||||
|
|
||||||
FuncGraphPtr func_graph = res->func_graph();
|
FuncGraphPtr func_graph = res->func_graph();
|
||||||
bool changed = opt::CleanList(func_graph, res->manager());
|
bool changed = opt::CleanAfterOptA(func_graph, res->manager());
|
||||||
|
|
||||||
abstract::AbstractBasePtrList args_spec;
|
abstract::AbstractBasePtrList args_spec;
|
||||||
auto parameters = func_graph->parameters();
|
auto parameters = func_graph->parameters();
|
||||||
|
@ -337,7 +337,7 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
|
||||||
|
|
||||||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
{"opt_a", OptPassAGroup},
|
{"opt_a", OptPassAGroup},
|
||||||
{"clean_list", CleanListPass},
|
{"clean_after_opta", CleanAfterOptAPass},
|
||||||
{"opt_b", OptPassBGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
{"cconv", CconvPass},
|
{"cconv", CconvPass},
|
||||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||||
|
@ -346,7 +346,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
||||||
|
|
||||||
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
{"opt_a", OptPassAGroup},
|
{"opt_a", OptPassAGroup},
|
||||||
{"clean_list", CleanListPass},
|
{"clean_after_opta", CleanAfterOptAPass},
|
||||||
{"opt_b", OptPassBGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
{"add_control_depend", AddControlDependPass},
|
{"add_control_depend", AddControlDependPass},
|
||||||
{"opt_control", ControlGroup},
|
{"opt_control", ControlGroup},
|
||||||
|
|
|
@ -17,13 +17,14 @@ Neural Networks Cells.
|
||||||
|
|
||||||
Pre-defined building blocks or computing units to construct Neural Networks.
|
Pre-defined building blocks or computing units to construct Neural Networks.
|
||||||
"""
|
"""
|
||||||
from . import layer, loss, optim, metrics, wrap, probability
|
from . import layer, loss, optim, metrics, wrap, probability, sparse
|
||||||
from .cell import Cell, GraphKernel
|
from .cell import Cell, GraphKernel
|
||||||
from .layer import *
|
from .layer import *
|
||||||
from .loss import *
|
from .loss import *
|
||||||
from .optim import *
|
from .optim import *
|
||||||
from .metrics import *
|
from .metrics import *
|
||||||
from .wrap import *
|
from .wrap import *
|
||||||
|
from .sparse import *
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Cell", "GraphKernel"]
|
__all__ = ["Cell", "GraphKernel"]
|
||||||
|
@ -32,7 +33,7 @@ __all__.extend(loss.__all__)
|
||||||
__all__.extend(optim.__all__)
|
__all__.extend(optim.__all__)
|
||||||
__all__.extend(metrics.__all__)
|
__all__.extend(metrics.__all__)
|
||||||
__all__.extend(wrap.__all__)
|
__all__.extend(wrap.__all__)
|
||||||
|
__all__.extend(sparse.__all__)
|
||||||
|
|
||||||
|
|
||||||
__all__.sort()
|
__all__.sort()
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Sparse related transformation.
|
||||||
|
"""
|
||||||
|
from .sparse import SparseToDense
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SparseToDense",
|
||||||
|
]
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Sparse related tools."""
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from ..cell import Cell
|
||||||
|
|
||||||
|
|
||||||
|
class SparseToDense(Cell):
|
||||||
|
"""
|
||||||
|
Convert a sparse tensor into dense.
|
||||||
|
|
||||||
|
Not yet supported by any backend at the moment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_tensor (SparseTensor): the sparse tensor to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, the tensor converted.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class SparseToDenseCell(nn.Cell):
|
||||||
|
>>> def __init__(self, dense_shape):
|
||||||
|
>>> super(SparseToDenseCell, self).__init__()
|
||||||
|
>>> self.dense_shape = dense_shape
|
||||||
|
>>> self.sparse_to_dense = nn.SparseToDense()
|
||||||
|
>>> def construct(self, indices, values):
|
||||||
|
>>> sparse = SparseTensor(indices, values, self.dense_shape)
|
||||||
|
>>> return self.sparse_to_dense(sparse)
|
||||||
|
>>>
|
||||||
|
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||||
|
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
>>> dense_shape = (3, 4)
|
||||||
|
>>> SparseToDenseCell(dense_shape)(indices, values)
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(SparseToDense, self).__init__()
|
||||||
|
self.sparse_to_dense = P.SparseToDense()
|
||||||
|
|
||||||
|
def construct(self, sparse_tensor):
|
||||||
|
return self.sparse_to_dense(sparse_tensor.indices(),
|
||||||
|
sparse_tensor.values(),
|
||||||
|
sparse_tensor.dense_shape())
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
"""grad impl."""
|
"""grad impl."""
|
||||||
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
|
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
|
||||||
grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
|
grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse
|
||||||
from .grad_base import get_bprop_fn
|
from .grad_base import get_bprop_fn
|
||||||
|
|
||||||
__all__ = ['get_bprop_fn']
|
__all__ = ['get_bprop_fn']
|
||||||
|
|
|
@ -116,6 +116,7 @@ def bprop_tuple_getitem(data, idx, out, dout):
|
||||||
"""Backpropagator for primitive `tuple_getitem`."""
|
"""Backpropagator for primitive `tuple_getitem`."""
|
||||||
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("list_getitem")
|
@bprops.register("list_getitem")
|
||||||
def bprop_list_getitem(data, idx, out, dout):
|
def bprop_list_getitem(data, idx, out, dout):
|
||||||
"""Backpropagator for primitive `list_getitem`."""
|
"""Backpropagator for primitive `list_getitem`."""
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""bprop primitives"""
|
||||||
|
from .. import functional as F
|
||||||
|
from .. import operations as P
|
||||||
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
|
from .grad_base import bprops, bprop_getters
|
||||||
|
|
||||||
|
# Unused parameters are placeholders.
|
||||||
|
|
||||||
|
|
||||||
|
@bprops.register("MakeSparseTensor")
|
||||||
|
def bprop_make_sparse_tensor(indices, values, dense_shape, out, dout):
|
||||||
|
"""Backpropagator for primitive `MakeSparseTensor`."""
|
||||||
|
return zeros_like(indices), F.sparse_tensor_get_values(dout), ()
|
||||||
|
|
||||||
|
|
||||||
|
@bprops.register("SparseTensorGetIndices")
|
||||||
|
def bprop_sparse_tensor_get_indices(sparse_tensor, out, dout):
|
||||||
|
"""Backpropagator for primitive `SparseTensorGetIndices`."""
|
||||||
|
return (zeros_like(sparse_tensor),)
|
||||||
|
|
||||||
|
|
||||||
|
@bprops.register("SparseTensorGetValues")
|
||||||
|
def bprop_sparse_tensor_get_values(sparse_tensor, out, dout):
|
||||||
|
"""Backpropagator for primitive `SparseTensorGetValues`."""
|
||||||
|
return F.make_sparse_tensor(F.sparse_tensor_get_indices(sparse_tensor),
|
||||||
|
dout,
|
||||||
|
F.sparse_tensor_get_dense_shape(sparse_tensor))
|
||||||
|
|
||||||
|
|
||||||
|
@bprops.register("SparseTensorGetDenseShape")
|
||||||
|
def bprop_sparse_tensor_get_dense_shape(sparse_tensor, out, dout):
|
||||||
|
"""Backpropagator for primitive `SparseTensorGetDenseShape`."""
|
||||||
|
return (zeros_like(sparse_tensor),)
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.SparseToDense)
|
||||||
|
def get_bprop_sparse_to_dense(self):
|
||||||
|
"""Generate bprop for SparseToDense"""
|
||||||
|
|
||||||
|
def bprop(indices, values, dense_shape, out, dout):
|
||||||
|
return zeros_like(indices), dout, zeros_like(dense_shape)
|
||||||
|
|
||||||
|
return bprop
|
|
@ -42,6 +42,16 @@ def _ones_like_tensor(x):
|
||||||
return P.Fill()(P.DType()(x), P.Shape()(x), 1.0)
|
return P.Fill()(P.DType()(x), P.Shape()(x), 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@ones_like_leaf.register("SparseTensor")
|
||||||
|
def _ones_like_sparse_tensor(x):
|
||||||
|
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
|
||||||
|
values_ = F.sparse_tensor_get_values(x)
|
||||||
|
values = P.Fill()(P.DType()(values_),
|
||||||
|
P.Shape()(values_),
|
||||||
|
1.0)
|
||||||
|
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
|
||||||
|
|
||||||
|
|
||||||
ones_like = base.HyperMap(ones_like_leaf)
|
ones_like = base.HyperMap(ones_like_leaf)
|
||||||
"""
|
"""
|
||||||
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
|
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
|
||||||
|
|
|
@ -84,6 +84,7 @@ from ._quant_ops import *
|
||||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
|
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
|
||||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
|
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
|
||||||
from .thor_ops import *
|
from .thor_ops import *
|
||||||
|
from .sparse_ops import SparseToDense
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ReverseSequence',
|
'ReverseSequence',
|
||||||
|
@ -357,7 +358,8 @@ __all__ = [
|
||||||
"PopulationCount",
|
"PopulationCount",
|
||||||
"ParallelConcat",
|
"ParallelConcat",
|
||||||
"Push",
|
"Push",
|
||||||
"Pull"
|
"Pull",
|
||||||
|
'SparseToDense',
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__.sort()
|
__all__.sort()
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Operators for sparse operators."""
|
||||||
|
|
||||||
|
from ..._checkparam import Validator as validator
|
||||||
|
from ...common import dtype as mstype
|
||||||
|
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||||
|
|
||||||
|
class SparseToDense(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Convert a sparse representation into a dense tensor.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **indices** (Tensor) - The indices of sparse representation.
|
||||||
|
- **values** (Tensor) - Values corresponding to each row of indices.
|
||||||
|
- **dense_shape** (tuple) - A int tuple which specifies the shape of dense tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, the shape of tensor is dense_shape.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||||
|
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
>>> dense_shape = (3, 4)
|
||||||
|
>>> out = P.SparseToDense()(indices, values, dense_shape)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init index_select"""
|
||||||
|
self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
|
||||||
|
|
||||||
|
def __infer__(self, indices, values, dense_shape):
|
||||||
|
validator.check_subclass("indices", indices['dtype'], mstype.tensor, self.name)
|
||||||
|
validator.check_subclass("values", values['dtype'], mstype.tensor, self.name)
|
||||||
|
out = {'shape': dense_shape['value'],
|
||||||
|
'dtype': values['dtype'],
|
||||||
|
'value': None}
|
||||||
|
return out
|
|
@ -28,6 +28,7 @@ from mindspore import Tensor, SparseTensor, context
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||||
|
|
||||||
|
grad_op = C.GradOperation('get_all', get_all=True)
|
||||||
|
|
||||||
class MakeSparseTensor(nn.Cell):
|
class MakeSparseTensor(nn.Cell):
|
||||||
def __init__(self, dense_shape):
|
def __init__(self, dense_shape):
|
||||||
|
@ -45,15 +46,6 @@ def test_sparse_tensor_make_sparse_tensor():
|
||||||
|
|
||||||
|
|
||||||
def test_sparse_tensor_attr():
|
def test_sparse_tensor_attr():
|
||||||
grad_op = C.GradOperation('get_all', get_all=True)
|
|
||||||
class GradWrap(nn.Cell):
|
|
||||||
def __init__(self, network):
|
|
||||||
super(GradWrap, self).__init__()
|
|
||||||
self.network = network
|
|
||||||
def construct(self, input1, input2):
|
|
||||||
gout = grad_op(self.network)(input1, input2)
|
|
||||||
return gout
|
|
||||||
|
|
||||||
class SparseTensorGetAttr(nn.Cell):
|
class SparseTensorGetAttr(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SparseTensorGetAttr, self).__init__()
|
super(SparseTensorGetAttr, self).__init__()
|
||||||
|
@ -82,3 +74,20 @@ def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
|
||||||
dense_shape = (2, 2, 2)
|
dense_shape = (2, 2, 2)
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
MakeSparseTensor(dense_shape)(indices, values)
|
MakeSparseTensor(dense_shape)(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_tensor_to_tensor():
|
||||||
|
class SparseToDenseCell(nn.Cell):
|
||||||
|
def __init__(self, dense_shape):
|
||||||
|
super(SparseToDenseCell, self).__init__()
|
||||||
|
self.dense_shape = dense_shape
|
||||||
|
self.sparse_to_dense = nn.SparseToDense()
|
||||||
|
def construct(self, indices, values):
|
||||||
|
sparse = SparseTensor(indices, values, self.dense_shape)
|
||||||
|
return self.sparse_to_dense(sparse)
|
||||||
|
|
||||||
|
indices = Tensor([[0, 1], [1, 2]])
|
||||||
|
values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
dense_shape = (3, 4)
|
||||||
|
SparseToDenseCell(dense_shape)(indices, values)
|
||||||
|
grad_op(SparseToDenseCell(dense_shape))(indices, values)
|
||||||
|
|
|
@ -102,7 +102,7 @@ def test_with_no_bprop():
|
||||||
with_no_bprop = WithNoBprop()
|
with_no_bprop = WithNoBprop()
|
||||||
x = Tensor(1, dtype=ms.int32)
|
x = Tensor(1, dtype=ms.int32)
|
||||||
y = Tensor(2, dtype=ms.int32)
|
y = Tensor(2, dtype=ms.int32)
|
||||||
assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
|
C.grad_all(with_no_bprop)(x, y)
|
||||||
|
|
||||||
|
|
||||||
def test_grad_in_bprop_1():
|
def test_grad_in_bprop_1():
|
||||||
|
@ -263,10 +263,7 @@ def test_grad_inline_bprop_two_input():
|
||||||
net = InlineBpropTwoInput()
|
net = InlineBpropTwoInput()
|
||||||
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||||
input2 = Tensor(np.ones([2, 2]).astype(np.float32))
|
input2 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||||
grads = C.grad_all(net)(input1, input2)
|
C.grad_all(net)(input1, input2)
|
||||||
assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
|
||||||
assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
|
||||||
assert len(grads) == 2
|
|
||||||
|
|
||||||
|
|
||||||
class TwoInputBprop(nn.Cell):
|
class TwoInputBprop(nn.Cell):
|
||||||
|
@ -350,24 +347,6 @@ def test_refkey_bprop():
|
||||||
assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
||||||
|
|
||||||
|
|
||||||
class MulAddWithWrongOutputNum(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(MulAddWithWrongOutputNum, self).__init__()
|
|
||||||
|
|
||||||
def construct(self, x, y):
|
|
||||||
return 2 * x + y
|
|
||||||
|
|
||||||
def bprop(self, x, y, out, dout):
|
|
||||||
return (2 * dout,)
|
|
||||||
|
|
||||||
|
|
||||||
def test_grad_mul_add_with_wrong_output_num():
|
|
||||||
context.set_context(check_bprop=True)
|
|
||||||
mul_add = MulAddWithWrongOutputNum()
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
C.grad_all(mul_add)(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class MulAddWithWrongOutputType(nn.Cell):
|
class MulAddWithWrongOutputType(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MulAddWithWrongOutputType, self).__init__()
|
super(MulAddWithWrongOutputType, self).__init__()
|
||||||
|
|
Loading…
Reference in New Issue