forked from mindspore-Ecosystem/mindspore
!3537 add sparse operators
Merge pull request !3537 from riemann_penn/add_sparse_operator
This commit is contained in:
commit
f1e717554c
|
@ -32,9 +32,11 @@ namespace opt {
|
|||
using mindspore::abstract::AbstractAttribute;
|
||||
using mindspore::abstract::AbstractClass;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractIndexedSlices;
|
||||
using mindspore::abstract::AbstractJTagged;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractSparseTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractUndetermined;
|
||||
|
||||
|
@ -73,6 +75,19 @@ static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
|
|||
return std::make_shared<AbstractTuple>(abs_list->elements());
|
||||
}
|
||||
|
||||
if (t->isa<AbstractSparseTensor>()) {
|
||||
auto abs_sparse = dyn_cast<AbstractSparseTensor>(t);
|
||||
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
|
||||
return std::make_shared<AbstractTuple>(abstract_list);
|
||||
}
|
||||
|
||||
if (t->isa<AbstractIndexedSlices>()) {
|
||||
auto abs_indexed_slices = dyn_cast<AbstractIndexedSlices>(t);
|
||||
std::vector<AbstractBasePtr> abstract_list{abs_indexed_slices->indices(), abs_indexed_slices->values(),
|
||||
abs_indexed_slices->dense_shape()};
|
||||
return std::make_shared<AbstractTuple>(abstract_list);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -389,14 +404,44 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
|||
return changed;
|
||||
}
|
||||
|
||||
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||
AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
// Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
|
||||
(void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
|
||||
return node->func_graph()->NewCNode(inputs);
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int &index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
||||
const auto &inputs = node->inputs();
|
||||
// Inputs should be [spase_getattr, sparse]
|
||||
if (inputs.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Node's input number < 2.";
|
||||
}
|
||||
|
||||
AnfNodePtr sparse = inputs[1];
|
||||
MS_EXCEPTION_IF_NULL(sparse);
|
||||
auto cons_node = NewValueNode(index);
|
||||
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(index));
|
||||
cons_node->set_abstract(aptr);
|
||||
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node});
|
||||
}
|
||||
|
||||
bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(root);
|
||||
|
||||
bool changed = false;
|
||||
|
||||
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
||||
AnfNodeSet all_node = manager->all_nodes();
|
||||
auto all_node = manager->all_nodes();
|
||||
for (auto &node : all_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -409,6 +454,18 @@ bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
|||
new_node = ConvertListSetItemToTupleSetItem(cnode);
|
||||
} else if (IsValueNode<ValueList>(node)) {
|
||||
new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimMakeIndexedSlices)) {
|
||||
new_node = ConvertMakeSparseToMakeTuple(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetIndices)) {
|
||||
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetValues)) {
|
||||
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetDenseShape)) {
|
||||
new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2);
|
||||
}
|
||||
|
||||
if (new_node != nullptr) {
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace opt {
|
|||
|
||||
// Remove the class type from graphs
|
||||
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
|
||||
// Remove most uses of tuples from the graph
|
||||
// tuples that are returned will be kept
|
||||
|
|
|
@ -69,11 +69,11 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CleanListPass(const ResourcePtr &res) {
|
||||
bool CleanAfterOptAPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
bool changed = opt::CleanList(func_graph, res->manager());
|
||||
bool changed = opt::CleanAfterOptA(func_graph, res->manager());
|
||||
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
|
@ -337,7 +337,7 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
|
|||
|
||||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
{"clean_list", CleanListPass},
|
||||
{"clean_after_opta", CleanAfterOptAPass},
|
||||
{"opt_b", OptPassBGroup},
|
||||
{"cconv", CconvPass},
|
||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||
|
@ -346,7 +346,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
|
||||
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
{"clean_list", CleanListPass},
|
||||
{"clean_after_opta", CleanAfterOptAPass},
|
||||
{"opt_b", OptPassBGroup},
|
||||
{"add_control_depend", AddControlDependPass},
|
||||
{"opt_control", ControlGroup},
|
||||
|
|
|
@ -17,13 +17,14 @@ Neural Networks Cells.
|
|||
|
||||
Pre-defined building blocks or computing units to construct Neural Networks.
|
||||
"""
|
||||
from . import layer, loss, optim, metrics, wrap, probability
|
||||
from . import layer, loss, optim, metrics, wrap, probability, sparse
|
||||
from .cell import Cell, GraphKernel
|
||||
from .layer import *
|
||||
from .loss import *
|
||||
from .optim import *
|
||||
from .metrics import *
|
||||
from .wrap import *
|
||||
from .sparse import *
|
||||
|
||||
|
||||
__all__ = ["Cell", "GraphKernel"]
|
||||
|
@ -32,7 +33,7 @@ __all__.extend(loss.__all__)
|
|||
__all__.extend(optim.__all__)
|
||||
__all__.extend(metrics.__all__)
|
||||
__all__.extend(wrap.__all__)
|
||||
|
||||
__all__.extend(sparse.__all__)
|
||||
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Sparse related transformation.
|
||||
"""
|
||||
from .sparse import SparseToDense
|
||||
|
||||
__all__ = [
|
||||
"SparseToDense",
|
||||
]
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Sparse related tools."""
|
||||
from mindspore.ops import operations as P
|
||||
from ..cell import Cell
|
||||
|
||||
|
||||
class SparseToDense(Cell):
|
||||
"""
|
||||
Convert a sparse tensor into dense.
|
||||
|
||||
Not yet supported by any backend at the moment.
|
||||
|
||||
Args:
|
||||
sparse_tensor (SparseTensor): the sparse tensor to convert.
|
||||
|
||||
Returns:
|
||||
Tensor, the tensor converted.
|
||||
|
||||
Examples:
|
||||
>>> class SparseToDenseCell(nn.Cell):
|
||||
>>> def __init__(self, dense_shape):
|
||||
>>> super(SparseToDenseCell, self).__init__()
|
||||
>>> self.dense_shape = dense_shape
|
||||
>>> self.sparse_to_dense = nn.SparseToDense()
|
||||
>>> def construct(self, indices, values):
|
||||
>>> sparse = SparseTensor(indices, values, self.dense_shape)
|
||||
>>> return self.sparse_to_dense(sparse)
|
||||
>>>
|
||||
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||
>>> dense_shape = (3, 4)
|
||||
>>> SparseToDenseCell(dense_shape)(indices, values)
|
||||
"""
|
||||
def __init__(self):
|
||||
super(SparseToDense, self).__init__()
|
||||
self.sparse_to_dense = P.SparseToDense()
|
||||
|
||||
def construct(self, sparse_tensor):
|
||||
return self.sparse_to_dense(sparse_tensor.indices(),
|
||||
sparse_tensor.values(),
|
||||
sparse_tensor.dense_shape())
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
"""grad impl."""
|
||||
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
|
||||
grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
|
||||
grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse
|
||||
from .grad_base import get_bprop_fn
|
||||
|
||||
__all__ = ['get_bprop_fn']
|
||||
|
|
|
@ -116,6 +116,7 @@ def bprop_tuple_getitem(data, idx, out, dout):
|
|||
"""Backpropagator for primitive `tuple_getitem`."""
|
||||
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
||||
|
||||
|
||||
@bprops.register("list_getitem")
|
||||
def bprop_list_getitem(data, idx, out, dout):
|
||||
"""Backpropagator for primitive `list_getitem`."""
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""bprop primitives"""
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from .grad_base import bprops, bprop_getters
|
||||
|
||||
# Unused parameters are placeholders.
|
||||
|
||||
|
||||
@bprops.register("MakeSparseTensor")
|
||||
def bprop_make_sparse_tensor(indices, values, dense_shape, out, dout):
|
||||
"""Backpropagator for primitive `MakeSparseTensor`."""
|
||||
return zeros_like(indices), F.sparse_tensor_get_values(dout), ()
|
||||
|
||||
|
||||
@bprops.register("SparseTensorGetIndices")
|
||||
def bprop_sparse_tensor_get_indices(sparse_tensor, out, dout):
|
||||
"""Backpropagator for primitive `SparseTensorGetIndices`."""
|
||||
return (zeros_like(sparse_tensor),)
|
||||
|
||||
|
||||
@bprops.register("SparseTensorGetValues")
|
||||
def bprop_sparse_tensor_get_values(sparse_tensor, out, dout):
|
||||
"""Backpropagator for primitive `SparseTensorGetValues`."""
|
||||
return F.make_sparse_tensor(F.sparse_tensor_get_indices(sparse_tensor),
|
||||
dout,
|
||||
F.sparse_tensor_get_dense_shape(sparse_tensor))
|
||||
|
||||
|
||||
@bprops.register("SparseTensorGetDenseShape")
|
||||
def bprop_sparse_tensor_get_dense_shape(sparse_tensor, out, dout):
|
||||
"""Backpropagator for primitive `SparseTensorGetDenseShape`."""
|
||||
return (zeros_like(sparse_tensor),)
|
||||
|
||||
|
||||
@bprop_getters.register(P.SparseToDense)
|
||||
def get_bprop_sparse_to_dense(self):
|
||||
"""Generate bprop for SparseToDense"""
|
||||
|
||||
def bprop(indices, values, dense_shape, out, dout):
|
||||
return zeros_like(indices), dout, zeros_like(dense_shape)
|
||||
|
||||
return bprop
|
|
@ -42,6 +42,16 @@ def _ones_like_tensor(x):
|
|||
return P.Fill()(P.DType()(x), P.Shape()(x), 1.0)
|
||||
|
||||
|
||||
@ones_like_leaf.register("SparseTensor")
|
||||
def _ones_like_sparse_tensor(x):
|
||||
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
|
||||
values_ = F.sparse_tensor_get_values(x)
|
||||
values = P.Fill()(P.DType()(values_),
|
||||
P.Shape()(values_),
|
||||
1.0)
|
||||
return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x))
|
||||
|
||||
|
||||
ones_like = base.HyperMap(ones_like_leaf)
|
||||
"""
|
||||
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
|
||||
|
|
|
@ -84,6 +84,7 @@ from ._quant_ops import *
|
|||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
|
||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
|
||||
from .thor_ops import *
|
||||
from .sparse_ops import SparseToDense
|
||||
|
||||
__all__ = [
|
||||
'ReverseSequence',
|
||||
|
@ -357,7 +358,8 @@ __all__ = [
|
|||
"PopulationCount",
|
||||
"ParallelConcat",
|
||||
"Push",
|
||||
"Pull"
|
||||
"Pull",
|
||||
'SparseToDense',
|
||||
]
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# coding: utf-8
|
||||
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Operators for sparse operators."""
|
||||
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
class SparseToDense(PrimitiveWithInfer):
|
||||
"""
|
||||
Convert a sparse representation into a dense tensor.
|
||||
|
||||
Inputs:
|
||||
- **indices** (Tensor) - The indices of sparse representation.
|
||||
- **values** (Tensor) - Values corresponding to each row of indices.
|
||||
- **dense_shape** (tuple) - A int tuple which specifies the shape of dense tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, the shape of tensor is dense_shape.
|
||||
|
||||
Examples:
|
||||
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||
>>> dense_shape = (3, 4)
|
||||
>>> out = P.SparseToDense()(indices, values, dense_shape)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init index_select"""
|
||||
self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
|
||||
|
||||
def __infer__(self, indices, values, dense_shape):
|
||||
validator.check_subclass("indices", indices['dtype'], mstype.tensor, self.name)
|
||||
validator.check_subclass("values", values['dtype'], mstype.tensor, self.name)
|
||||
out = {'shape': dense_shape['value'],
|
||||
'dtype': values['dtype'],
|
||||
'value': None}
|
||||
return out
|
|
@ -28,6 +28,7 @@ from mindspore import Tensor, SparseTensor, context
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||
|
||||
grad_op = C.GradOperation('get_all', get_all=True)
|
||||
|
||||
class MakeSparseTensor(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
|
@ -45,15 +46,6 @@ def test_sparse_tensor_make_sparse_tensor():
|
|||
|
||||
|
||||
def test_sparse_tensor_attr():
|
||||
grad_op = C.GradOperation('get_all', get_all=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
def construct(self, input1, input2):
|
||||
gout = grad_op(self.network)(input1, input2)
|
||||
return gout
|
||||
|
||||
class SparseTensorGetAttr(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseTensorGetAttr, self).__init__()
|
||||
|
@ -82,3 +74,20 @@ def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
|
|||
dense_shape = (2, 2, 2)
|
||||
with pytest.raises(TypeError):
|
||||
MakeSparseTensor(dense_shape)(indices, values)
|
||||
|
||||
|
||||
def test_sparse_tensor_to_tensor():
|
||||
class SparseToDenseCell(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
super(SparseToDenseCell, self).__init__()
|
||||
self.dense_shape = dense_shape
|
||||
self.sparse_to_dense = nn.SparseToDense()
|
||||
def construct(self, indices, values):
|
||||
sparse = SparseTensor(indices, values, self.dense_shape)
|
||||
return self.sparse_to_dense(sparse)
|
||||
|
||||
indices = Tensor([[0, 1], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
dense_shape = (3, 4)
|
||||
SparseToDenseCell(dense_shape)(indices, values)
|
||||
grad_op(SparseToDenseCell(dense_shape))(indices, values)
|
||||
|
|
|
@ -102,7 +102,7 @@ def test_with_no_bprop():
|
|||
with_no_bprop = WithNoBprop()
|
||||
x = Tensor(1, dtype=ms.int32)
|
||||
y = Tensor(2, dtype=ms.int32)
|
||||
assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
|
||||
C.grad_all(with_no_bprop)(x, y)
|
||||
|
||||
|
||||
def test_grad_in_bprop_1():
|
||||
|
@ -263,10 +263,7 @@ def test_grad_inline_bprop_two_input():
|
|||
net = InlineBpropTwoInput()
|
||||
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
input2 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
grads = C.grad_all(net)(input1, input2)
|
||||
assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
||||
assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
||||
assert len(grads) == 2
|
||||
C.grad_all(net)(input1, input2)
|
||||
|
||||
|
||||
class TwoInputBprop(nn.Cell):
|
||||
|
@ -350,24 +347,6 @@ def test_refkey_bprop():
|
|||
assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
|
||||
|
||||
|
||||
class MulAddWithWrongOutputNum(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MulAddWithWrongOutputNum, self).__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
return 2 * x + y
|
||||
|
||||
def bprop(self, x, y, out, dout):
|
||||
return (2 * dout,)
|
||||
|
||||
|
||||
def test_grad_mul_add_with_wrong_output_num():
|
||||
context.set_context(check_bprop=True)
|
||||
mul_add = MulAddWithWrongOutputNum()
|
||||
with pytest.raises(TypeError):
|
||||
C.grad_all(mul_add)(1, 2)
|
||||
|
||||
|
||||
class MulAddWithWrongOutputType(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MulAddWithWrongOutputType, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue