!40054 Enable dict in dict setitem

Merge pull request !40054 from LiangZhibo/dict
This commit is contained in:
i-robot 2022-08-11 01:12:10 +00:00 committed by Gitee
commit afb16cbba4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 68 additions and 5 deletions

View File

@ -246,6 +246,22 @@ def _dict_setitem_with_list(data, key, value):
return F.dict_setitem(data, key, value)
@setitem.register("Dictionary", "String", "Dictionary")
def _dict_setitem_with_dict(data, key, value):
"""
Assigns value to dictionary.
Inputs:
data (dict): Data of type dict.
key (str): Key of the data.
value (dict): Value given.
Outputs:
dict, type is as same as the element type of data.
"""
return F.dict_setitem(data, key, value)
@setitem.register("Dictionary", "String", "Tuple")
def _dict_setitem_with_tuple(data, key, value):
"""

View File

@ -98,9 +98,6 @@ def test_dict_set_or_get_item_2():
class DictNet(Cell):
"""DictNet1 definition"""
def __init__(self):
super(DictNet, self).__init__()
def construct(self):
tuple_1 = (1, 2, 3)
tuple_2 = (4, 5, 6)
@ -158,6 +155,58 @@ def test_dict_set_item():
_ = net(x)
def test_dict_set_item_2():
"""
Description: test dict in dict set item.
Expectation: the results are as expected.
"""
class DictSetNet(Cell):
def construct(self):
cur_dict = {"a": {"a0": 0, "a1": 1}}
cur_dict["a"]["a0"] = 3
cur_dict["a"]["a3"] = 3
cur_dict["b"] = {"b0": 0, "b1": 1}
return cur_dict
net = DictSetNet()
output = net()
assert len(output) == 2
first = output[0]
second = output[1]
assert len(first) == 3
assert first[0] == 3
assert first[1] == 1
assert first[2] == 3
assert len(second) == 2
assert second[0] == 0
assert second[1] == 1
def test_dict_set_item_3():
"""
Description: test dict in dict set item.
Expectation: the results are as expected.
"""
class DictSetNet(Cell):
def construct(self):
cur_dict = {"a": {"a0": {"a00": 0, "a01": 1}}, "b": 1}
cur_dict["a"]["a0"]["a00"] = 3
cur_dict["a"]["a0"]["a01"] = 3
return cur_dict
net = DictSetNet()
output = net()
assert len(output) == 2
first = output[0]
assert len(first) == 1
assert len(first[0]) == 2
assert first[0][0] == 3
assert first[0][1] == 3
# if the dictionary item does not exist, create a new one
def test_dict_set_item_create_new():
class DictSetNet(Cell):
@ -181,8 +230,6 @@ def test_dict_items():
"""
class DictItemsNet(Cell):
def __init__(self):
super(DictItemsNet, self).__init__()
def construct(self, x):
return x.items()