fix valuenode simplify

This commit is contained in:
panyifeng 2020-07-14 17:54:11 +08:00
parent bc0a53cfb1
commit 394178569e
2 changed files with 26 additions and 9 deletions

View File

@ -43,26 +43,28 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
return nullptr;
}
AbstractBasePtr res = t;
if (t->isa<AbstractClass>()) {
auto abs_class = dyn_cast<AbstractClass>(t);
AbstractBasePtrList baselist;
auto attributes = abs_class->attributes();
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractDictionary>()) {
return std::make_shared<AbstractTuple>(baselist);
}
if (t->isa<AbstractDictionary>()) {
auto abs_dict = dyn_cast<AbstractDictionary>(t);
AbstractBasePtrList baselist;
auto elements = abs_dict->elements();
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractList>()) {
auto abs_dict = dyn_cast<AbstractList>(t);
res = std::make_shared<AbstractTuple>(abs_dict->elements());
return std::make_shared<AbstractTuple>(baselist);
}
return res;
if (t->isa<AbstractList>()) {
auto abs_list = dyn_cast<AbstractList>(t);
return std::make_shared<AbstractTuple>(abs_list->elements());
}
return nullptr;
}
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
@ -376,7 +378,12 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract());
node->set_abstract(ret);
if (ret) {
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
<< ret->ToString();
node->set_abstract(ret);
changed = true;
}
}
return changed;
}

View File

@ -1031,3 +1031,13 @@ def test_grad_if_defer_inline():
inp = Tensor(np.ones([128, 96]).astype(np.float32))
grads = C.grad_all(network)(inp)
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
def test_dict_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.res = {'1': 10}
def construct(self):
return self.res
Net()()