!1250 Dict_setitem transofrm to tuple_setitem

Merge pull request !1250 from amongo/SupportDictSetItemTransform
This commit is contained in:
mindspore-ci-bot 2020-05-20 21:37:26 +08:00 committed by Gitee
commit 233508b70e
3 changed files with 60 additions and 4 deletions

View File

@ -139,6 +139,47 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
}
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
// Inputs should be [dict_setitem, dict, item, value]
const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs.");
AnfNodePtr data = inputs[1];
AnfNodePtr cons = inputs[2];
AnfNodePtr item_value = inputs[3];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(cons);
auto dt = data->abstract();
MS_EXCEPTION_IF_NULL(dt);
if (!dt->isa<abstract::AbstractDictionary>()) {
MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
}
auto cons_is_str = IsValueNode<StringImm>(cons);
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
const auto &cmap = ct->elements();
int count = 0;
for (auto &item : cmap) {
if (cons_is_str && item.first == cons_str) {
break;
}
count++;
}
if (IntToSize(count) >= cmap.size()) {
MS_LOG(EXCEPTION) << "dictionary assignment key " << cons_str
<< " does not exist, can not create new dictionary item for now.";
}
auto idx_c = NewValueNode(count);
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
idx_c->set_abstract(aptr);
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
}
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
@ -300,6 +341,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
new_node = ErasePartialNode(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
new_node = ConvertDictGetItemToTupleGetItem(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
new_node = ConvertDictSetItemToTupleSetItem(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
new_node = EraseMakeDictNode(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {

View File

@ -138,7 +138,7 @@ class GetSetitemEliminater : public AnfVisitor {
if (key1_ == key2_) {
return last_;
}
return fg->NewCNode({op_, tuple_, c2_});
return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_, c2_});
}
return nullptr;
}
@ -148,7 +148,7 @@ class GetSetitemEliminater : public AnfVisitor {
if (cnode->size() < 4) {
return;
}
op_ = cnode->input(0);
tuple_ = cnode->input(1);
last_ = cnode->input(3);
@ -174,7 +174,6 @@ class GetSetitemEliminater : public AnfVisitor {
void Reset() {
key1_ = -1;
key2_ = -1;
op_ = nullptr;
c2_ = nullptr;
last_ = nullptr;
tuple_ = nullptr;
@ -184,7 +183,7 @@ class GetSetitemEliminater : public AnfVisitor {
private:
bool is_in_set_{false};
int key1_{-1}, key2_{-1};
AnfNodePtr op_{nullptr}, tuple_{nullptr}, last_{nullptr}, c2_{nullptr};
AnfNodePtr tuple_{nullptr}, last_{nullptr}, c2_{nullptr};
};
// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} ->

View File

@ -136,3 +136,17 @@ def test_dict_set_or_get_item_3():
net = DictNet()
assert net() == Tensor(np.ones([4, 2, 3], np.float32))
def test_dict_set_item():
class DictSetNet(Cell):
def __init__(self):
super(DictSetNet, self).__init__()
self.attrs = ("abc", "edf", "ghi", "jkl")
def construct(self, x):
my_dict = {"def": x, "abc":x, "edf":x, "ghi":x, "jkl":x}
for i in range(len(self.attrs)):
my_dict[self.attrs[i]] = x - i
return my_dict["jkl"], my_dict["edf"]
x = Tensor(np.ones([2, 2, 3], np.float32))
net = DictSetNet()
out = net(x)