forked from mindspore-Ecosystem/mindspore
fix valuenode simplify
This commit is contained in:
parent
bc0a53cfb1
commit
394178569e
|
@ -43,26 +43,28 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
AbstractBasePtr res = t;
|
||||
if (t->isa<AbstractClass>()) {
|
||||
auto abs_class = dyn_cast<AbstractClass>(t);
|
||||
AbstractBasePtrList baselist;
|
||||
auto attributes = abs_class->attributes();
|
||||
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
|
||||
[](const AbstractAttribute &item) { return item.second; });
|
||||
res = std::make_shared<AbstractTuple>(baselist);
|
||||
} else if (t->isa<AbstractDictionary>()) {
|
||||
return std::make_shared<AbstractTuple>(baselist);
|
||||
}
|
||||
if (t->isa<AbstractDictionary>()) {
|
||||
auto abs_dict = dyn_cast<AbstractDictionary>(t);
|
||||
AbstractBasePtrList baselist;
|
||||
auto elements = abs_dict->elements();
|
||||
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
|
||||
[](const AbstractAttribute &item) { return item.second; });
|
||||
res = std::make_shared<AbstractTuple>(baselist);
|
||||
} else if (t->isa<AbstractList>()) {
|
||||
auto abs_dict = dyn_cast<AbstractList>(t);
|
||||
res = std::make_shared<AbstractTuple>(abs_dict->elements());
|
||||
return std::make_shared<AbstractTuple>(baselist);
|
||||
}
|
||||
return res;
|
||||
if (t->isa<AbstractList>()) {
|
||||
auto abs_list = dyn_cast<AbstractList>(t);
|
||||
return std::make_shared<AbstractTuple>(abs_list->elements());
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
|
||||
|
@ -376,7 +378,12 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
|||
|
||||
for (auto &node : manager->all_nodes()) {
|
||||
auto ret = Reabs(node->abstract());
|
||||
node->set_abstract(ret);
|
||||
if (ret) {
|
||||
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
|
||||
<< ret->ToString();
|
||||
node->set_abstract(ret);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
|
|
@ -1031,3 +1031,13 @@ def test_grad_if_defer_inline():
|
|||
inp = Tensor(np.ones([128, 96]).astype(np.float32))
|
||||
grads = C.grad_all(network)(inp)
|
||||
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
|
||||
|
||||
|
||||
def test_dict_const():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.res = {'1': 10}
|
||||
def construct(self):
|
||||
return self.res
|
||||
Net()()
|
||||
|
|
Loading…
Reference in New Issue