make abstractref can join with abstracttensor

This commit is contained in:
zhousiyi 2020-07-07 11:25:37 +00:00
parent 1f4944fa15
commit 9ad7c652a2
2 changed files with 24 additions and 2 deletions

View File

@ -838,7 +838,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
auto new_ref = ref_->Join(other);
return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_);
}
if (*this == *other) {
return shared_from_base<AbstractBase>();

View File

@ -22,7 +22,7 @@ from mindspore.common.parameter import ParameterTuple
from mindspore.nn import Cell
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def test_net_vargs_expand():
@ -184,6 +184,27 @@ def test_grad_var_args_with_sens():
_ = grad_net(x, y, sens)
def test_grad_with_param_sens():
""""test grad_with_sens parameter"""
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
self.sens = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), name='sens', requires_grad=False)
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
def construct(self, x, y):
return self.grad(self.net, self.weights)(x, y, self.sens)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = SecondNet()
grad_net = GradNet(net)
_ = grad_net(x, y)
def test_var_args_grad():
class VarNet(Cell):
def __init__(self, net):