forked from OSSInnovation/mindspore
make abstractref can join with abstracttensor
This commit is contained in:
parent
1f4944fa15
commit
9ad7c652a2
|
@ -838,7 +838,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
|
|||
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
|
||||
auto other_ref = other->cast<AbstractRefPtr>();
|
||||
if (other_ref == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
auto new_ref = ref_->Join(other);
|
||||
return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_);
|
||||
}
|
||||
if (*this == *other) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.common.parameter import ParameterTuple
|
|||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
|
||||
|
||||
def test_net_vargs_expand():
|
||||
|
@ -184,6 +184,27 @@ def test_grad_var_args_with_sens():
|
|||
_ = grad_net(x, y, sens)
|
||||
|
||||
|
||||
def test_grad_with_param_sens():
|
||||
""""test grad_with_sens parameter"""
|
||||
|
||||
class GradNet(Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
self.net = net
|
||||
self.sens = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), name='sens', requires_grad=False)
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.grad(self.net, self.weights)(x, y, self.sens)
|
||||
|
||||
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
net = SecondNet()
|
||||
grad_net = GradNet(net)
|
||||
_ = grad_net(x, y)
|
||||
|
||||
|
||||
def test_var_args_grad():
|
||||
class VarNet(Cell):
|
||||
def __init__(self, net):
|
||||
|
|
Loading…
Reference in New Issue