forked from mindspore-Ecosystem/mindspore
!1250 Dict_setitem transofrm to tuple_setitem
Merge pull request !1250 from amongo/SupportDictSetItemTransform
This commit is contained in:
commit
233508b70e
|
@ -139,6 +139,47 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
|
|||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
||||
// Inputs should be [dict_setitem, dict, item, value]
|
||||
const auto &inputs = node->inputs();
|
||||
MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs.");
|
||||
|
||||
AnfNodePtr data = inputs[1];
|
||||
AnfNodePtr cons = inputs[2];
|
||||
AnfNodePtr item_value = inputs[3];
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_EXCEPTION_IF_NULL(cons);
|
||||
|
||||
auto dt = data->abstract();
|
||||
MS_EXCEPTION_IF_NULL(dt);
|
||||
if (!dt->isa<abstract::AbstractDictionary>()) {
|
||||
MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
|
||||
}
|
||||
auto cons_is_str = IsValueNode<StringImm>(cons);
|
||||
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
|
||||
|
||||
auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
|
||||
const auto &cmap = ct->elements();
|
||||
int count = 0;
|
||||
for (auto &item : cmap) {
|
||||
if (cons_is_str && item.first == cons_str) {
|
||||
break;
|
||||
}
|
||||
count++;
|
||||
}
|
||||
if (IntToSize(count) >= cmap.size()) {
|
||||
MS_LOG(EXCEPTION) << "dictionary assignment key " << cons_str
|
||||
<< " does not exist, can not create new dictionary item for now.";
|
||||
}
|
||||
auto idx_c = NewValueNode(count);
|
||||
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
|
||||
idx_c->set_abstract(aptr);
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
@ -300,6 +341,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
|||
new_node = ErasePartialNode(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
|
||||
new_node = ConvertDictGetItemToTupleGetItem(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
|
||||
new_node = ConvertDictSetItemToTupleSetItem(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
|
||||
new_node = EraseMakeDictNode(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
|
||||
|
|
|
@ -138,7 +138,7 @@ class GetSetitemEliminater : public AnfVisitor {
|
|||
if (key1_ == key2_) {
|
||||
return last_;
|
||||
}
|
||||
return fg->NewCNode({op_, tuple_, c2_});
|
||||
return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_, c2_});
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ class GetSetitemEliminater : public AnfVisitor {
|
|||
if (cnode->size() < 4) {
|
||||
return;
|
||||
}
|
||||
op_ = cnode->input(0);
|
||||
|
||||
tuple_ = cnode->input(1);
|
||||
last_ = cnode->input(3);
|
||||
|
||||
|
@ -174,7 +174,6 @@ class GetSetitemEliminater : public AnfVisitor {
|
|||
void Reset() {
|
||||
key1_ = -1;
|
||||
key2_ = -1;
|
||||
op_ = nullptr;
|
||||
c2_ = nullptr;
|
||||
last_ = nullptr;
|
||||
tuple_ = nullptr;
|
||||
|
@ -184,7 +183,7 @@ class GetSetitemEliminater : public AnfVisitor {
|
|||
private:
|
||||
bool is_in_set_{false};
|
||||
int key1_{-1}, key2_{-1};
|
||||
AnfNodePtr op_{nullptr}, tuple_{nullptr}, last_{nullptr}, c2_{nullptr};
|
||||
AnfNodePtr tuple_{nullptr}, last_{nullptr}, c2_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} ->
|
||||
|
|
|
@ -136,3 +136,17 @@ def test_dict_set_or_get_item_3():
|
|||
|
||||
net = DictNet()
|
||||
assert net() == Tensor(np.ones([4, 2, 3], np.float32))
|
||||
|
||||
def test_dict_set_item():
|
||||
class DictSetNet(Cell):
|
||||
def __init__(self):
|
||||
super(DictSetNet, self).__init__()
|
||||
self.attrs = ("abc", "edf", "ghi", "jkl")
|
||||
def construct(self, x):
|
||||
my_dict = {"def": x, "abc":x, "edf":x, "ghi":x, "jkl":x}
|
||||
for i in range(len(self.attrs)):
|
||||
my_dict[self.attrs[i]] = x - i
|
||||
return my_dict["jkl"], my_dict["edf"]
|
||||
x = Tensor(np.ones([2, 2, 3], np.float32))
|
||||
net = DictSetNet()
|
||||
out = net(x)
|
Loading…
Reference in New Issue