forked from mindspore-Ecosystem/mindspore
get keys and values from dictionary & set tuple to dictionary
This commit is contained in:
parent
d2b1e783e7
commit
3b21822824
|
@ -304,6 +304,13 @@ AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
|
|||
return inputs[2];
|
||||
}
|
||||
|
||||
AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &inputs = node->inputs();
|
||||
MS_ASSERT(inputs.size() == 2 && "DictGetValues should have two inputs");
|
||||
return inputs[1];
|
||||
}
|
||||
|
||||
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &inputs = node->inputs();
|
||||
|
@ -374,6 +381,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
|||
new_node = ConvertDictGetItemToTupleGetItem(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
|
||||
new_node = ConvertDictSetItemToTupleSetItem(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) {
|
||||
new_node = EraseDictGetValues(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
|
||||
new_node = EraseMakeDictNode(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
|
||||
|
|
|
@ -141,6 +141,8 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
|
||||
{"__bool__", std::string("dict_bool")} // C.dict_bool
|
||||
}},
|
||||
{kObjectTypeTensorType,
|
||||
|
|
|
@ -131,6 +131,10 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -249,6 +249,32 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
|
|||
return std::make_shared<AbstractDictionary>(dict_elems);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
AbstractBasePtrList keys;
|
||||
std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(keys),
|
||||
[](const AbstractAttribute &item) { return std::make_shared<AbstractScalar>(item.first); });
|
||||
return std::make_shared<AbstractTuple>(keys);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
AbstractBasePtrList values;
|
||||
std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(values),
|
||||
[](const AbstractAttribute &item) { return item.second; });
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a list and an object of a subclass of AbstractBase.
|
||||
|
|
|
@ -72,6 +72,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
|
||||
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
|
||||
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
|
||||
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}},
|
||||
{prim::kPrimDictGetValues, {InferImplDictGetValues, true}},
|
||||
{prim::kPrimListAppend, {InferImplListAppend, true}},
|
||||
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
|
||||
{prim::kPrimListLen, {InferImplListLen, true}},
|
||||
|
|
|
@ -279,6 +279,8 @@ inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_g
|
|||
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
|
||||
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
|
||||
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
|
||||
inline const PrimitivePtr kPrimDictGetKeys = std::make_shared<Primitive>("dict_getkeys");
|
||||
inline const PrimitivePtr kPrimDictGetValues = std::make_shared<Primitive>("dict_getvalues");
|
||||
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
|
||||
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
|
||||
|
||||
|
|
|
@ -132,6 +132,20 @@ def _dict_setitem_with_number(data, key, value):
|
|||
"""
|
||||
return F.dict_setitem(data, key, value)
|
||||
|
||||
@setitem.register("Dictionary", "String", "Tuple")
|
||||
def _dict_setitem_with_tuple(data, key, value):
|
||||
"""
|
||||
Assigns value to dictionary.
|
||||
|
||||
Inputs:
|
||||
data (dict): Data of type dict.
|
||||
key (str): Key of the data.
|
||||
value (Tuple): Value given.
|
||||
|
||||
Outputs:
|
||||
dict, type is as same as the element type of data.
|
||||
"""
|
||||
return F.dict_setitem(data, key, value)
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Tensor")
|
||||
def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_dictionary """
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
|
||||
|
||||
class Net1(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
dic = {'x': 0, 'y': 1}
|
||||
output = []
|
||||
for i in dic.keys():
|
||||
output.append(i)
|
||||
for j in dic.values():
|
||||
output.append(j)
|
||||
return output
|
||||
|
||||
class Net2(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
dic = {'x': x, 'y': 1}
|
||||
output = []
|
||||
for i in dic.keys():
|
||||
output.append(i)
|
||||
for j in dic.values():
|
||||
output.append(j)
|
||||
return output
|
||||
|
||||
class Net3(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
dic = {'x': 0}
|
||||
dic['y'] = (0, 1)
|
||||
output = []
|
||||
for i in dic.keys():
|
||||
output.append(i)
|
||||
for j in dic.values():
|
||||
output.append(j)
|
||||
return output
|
||||
|
||||
def test_dict1():
|
||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
net = Net1()
|
||||
out_me = net(input_me)
|
||||
assert out_me == ('x', 'y', 0, 1)
|
||||
|
||||
|
||||
def test_dict2():
|
||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
net = Net2()
|
||||
net(input_me)
|
||||
|
||||
def test_dict3():
|
||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
net = Net3()
|
||||
out_me = net(input_me)
|
||||
assert out_me == ('x', 'y', 0, (0, 1))
|
Loading…
Reference in New Issue