From e9aa280f86ae96b7e718971a8caa6c8d74b9e7f2 Mon Sep 17 00:00:00 2001 From: huangdongrun Date: Sat, 23 May 2020 09:44:12 +0800 Subject: [PATCH] add support for dictionary set item which does not exist --- mindspore/ccsrc/optimizer/clean.cc | 7 +++++-- tests/ut/python/dtype/test_dictionary.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc index 4ada67893e0..fafe26e2ed0 100644 --- a/mindspore/ccsrc/optimizer/clean.cc +++ b/mindspore/ccsrc/optimizer/clean.cc @@ -24,6 +24,7 @@ #include #include "./common.h" #include "debug/trace.h" +#include "operator/composite/composite.h" namespace mindspore { /* namespace to support opt */ @@ -171,8 +172,10 @@ AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) { count++; } if (IntToSize(count) >= cmap.size()) { - MS_LOG(EXCEPTION) << "dictionary assignment key " << cons_str - << " does not exist, can not create new dictionary item for now."; + // for dictionary set, if the key does not exist, we should create a new item + auto tuple_add_op = std::make_shared("tuple_add"); + auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value}); + return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item}); } auto idx_c = NewValueNode(count); AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); diff --git a/tests/ut/python/dtype/test_dictionary.py b/tests/ut/python/dtype/test_dictionary.py index 049ad340655..26c955d89de 100644 --- a/tests/ut/python/dtype/test_dictionary.py +++ b/tests/ut/python/dtype/test_dictionary.py @@ -153,3 +153,19 @@ def test_dict_set_item(): x = Tensor(np.ones([2, 2, 3], np.float32)) net = DictSetNet() out = net(x) + + +# if the dictionary item does not exist, create a new one +def test_dict_set_item_create_new(): + class DictSetNet(Cell): + def __init__(self): + super(DictSetNet, self).__init__() + self.attrs = ("abc", "edf", "ghi", "jkl") + def construct(self, x): + my_dict = {"def": x} + for i in range(len(self.attrs)): + my_dict[self.attrs[i]] = x - i + return my_dict + x = Tensor(np.ones([2, 2, 3], np.float32)) + net = DictSetNet() + out = net(x)