!49589 [1.10] Fix dic.get error

Merge pull request !49589 from huanghui/r1.10-fix-dict-get
This commit is contained in:
i-robot 2023-03-03 14:05:19 +00:00 committed by Gitee
commit cce56aa486
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 873 additions and 25 deletions

View File

@ -0,0 +1,179 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/operator/composite/dict_operation.h"
#include <vector>
#include <utility>
#include <algorithm>
#include "abstract/param_validator.h"
#include "frontend/optimizer/opt.h"
#include "include/common/pybind_api/api_register.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
FuncGraphPtr DictClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
constexpr size_t dict_clear_args_size = 1;
abstract::CheckArgsSize("DictClear", args_list, dict_clear_args_size);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("clear");
(void)ret->add_parameter();
auto empty_dict = std::vector<std::pair<std::string, ValuePtr>>();
ret->set_output(NewValueNode(std::make_shared<ValueDictionary>(empty_dict)));
return ret;
}
FuncGraphPtr DictHasKey::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
constexpr size_t dict_has_key_args_size = 2;
abstract::CheckArgsSize("DictHasKey", args_list, dict_has_key_args_size);
auto dict = dyn_cast<abstract::AbstractDictionary>(args_list[0]);
ValuePtr key_value = args_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(dict);
MS_EXCEPTION_IF_NULL(key_value);
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "The key should be string, but got " << key_value->ToString();
}
auto key_str = GetValue<std::string>(key_value);
auto dict_elems = dict->elements();
bool has_key = false;
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
[&key_str](const abstract::AbstractAttribute &item) { return item.first == key_str; });
if (it != dict_elems.end()) {
has_key = true;
}
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("has_key");
(void)ret->add_parameter();
(void)ret->add_parameter();
auto out = NewValueNode(MakeValue(has_key));
ret->set_output(out);
return ret;
}
FuncGraphPtr DictUpdate::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
constexpr size_t dict_update_args_size = 2;
abstract::CheckArgsSize("DictUpdate", args_list, dict_update_args_size);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("update");
AnfNodePtrList key_inputs;
AnfNodePtrList value_inputs;
(void)key_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
(void)value_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
std::unordered_map<std::string, size_t> hash_map;
AddNodeToLists(args_list[0], ret, &key_inputs, &value_inputs, &hash_map);
AddNodeToLists(args_list[1], ret, &key_inputs, &value_inputs, &hash_map);
ret->set_output(ret->NewCNode(
{NewValueNode(prim::kPrimMakeDict), ret->NewCNode(std::move(key_inputs)), ret->NewCNode(std::move(value_inputs))}));
return ret;
}
void DictUpdate::AddNodeToLists(const AbstractBasePtr &arg, const FuncGraphPtr &ret, AnfNodePtrList *keys,
AnfNodePtrList *values, std::unordered_map<std::string, size_t> *hash_map) {
auto dict = dyn_cast<abstract::AbstractDictionary>(arg);
MS_EXCEPTION_IF_NULL(dict);
auto &dict_elems = dict->elements();
auto arg_node = ret->add_parameter();
for (const auto &elem : dict_elems) {
AnfNodePtr dict_value = ret->NewCNode({NewValueNode(prim::kPrimDictGetItem), arg_node, NewValueNode(elem.first)});
auto map_find = hash_map->find(elem.first);
if (map_find == hash_map->end()) {
hash_map->insert(std::make_pair(elem.first, values->size()));
(void)keys->emplace_back(NewValueNode(elem.first));
(void)values->emplace_back(dict_value);
} else {
values->at(map_find->second) = dict_value;
}
}
}
FuncGraphPtr DictFromKeys::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
constexpr size_t dict_fromkeys_args_size = 3;
abstract::CheckArgsSize("DictFromKeys", args_list, dict_fromkeys_args_size);
const auto &values = ParseIterableObject(args_list[1]);
auto value_node = args_list[2]->BuildValue();
MS_EXCEPTION_IF_NULL(value_node);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("fromkeys");
(void)ret->add_parameter();
(void)ret->add_parameter();
(void)ret->add_parameter();
std::vector<std::pair<std::string, ValuePtr>> key_values;
for (auto &value : values) {
auto key = value->BuildValue();
if (!key->IsSameTypeId(StringImm::kTypeId)) {
MS_LOG(EXCEPTION) << "The key should be string, but got " << key->type_name();
}
std::string key_node = GetValue<std::string>(key);
(void)key_values.emplace_back(std::make_pair(key_node, value_node));
}
ret->set_output(NewValueNode(std::make_shared<ValueDictionary>(key_values)));
return ret;
}
abstract::AbstractBasePtrList DictFromKeys::ParseIterableObject(const abstract::AbstractBasePtr &arg_key) {
auto key_type = arg_key->BuildType();
if (key_type->IsSameTypeId(List::kTypeId) || key_type->IsSameTypeId(Tuple::kTypeId)) {
abstract::AbstractSequencePtr dict_keys = dyn_cast<abstract::AbstractSequence>(arg_key);
MS_EXCEPTION_IF_NULL(dict_keys);
return dict_keys->elements();
}
if (key_type->IsSameTypeId(Dictionary::kTypeId)) {
auto dict_keys = dyn_cast<abstract::AbstractDictionary>(arg_key);
MS_EXCEPTION_IF_NULL(dict_keys);
AbstractBasePtrList keys;
auto &dict_elems = dict_keys->elements();
(void)std::transform(
dict_elems.begin(), dict_elems.end(), std::back_inserter(keys),
[](const abstract::AbstractAttribute &item) { return std::make_shared<abstract::AbstractScalar>(item.first); });
return keys;
}
if (key_type->IsSameTypeId(String::kTypeId)) {
auto value = arg_key->BuildValue();
MS_EXCEPTION_IF_NULL(value);
string dict_keys = value->ToString();
AbstractBasePtrList keys;
(void)std::transform(dict_keys.begin(), dict_keys.end(), std::back_inserter(keys), [](const char &item) {
return std::make_shared<abstract::AbstractScalar>(std::string(1, item));
});
return keys;
}
MS_LOG(EXCEPTION) << key_type->ToString() << " object is not iterable";
}
} // namespace prim
} // namespace mindspore

View File

@ -0,0 +1,90 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DICT_OPERATION_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DICT_OPERATION_H_
#include <string>
#include <unordered_map>
#include <memory>
#include "ir/meta_func_graph.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
class DictClear : public MetaFuncGraph {
public:
explicit DictClear(const std::string &name) : MetaFuncGraph(name) {}
~DictClear() override = default;
MS_DECLARE_PARENT(DictClear, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const DictClear &dict_clear) {
os << dict_clear.name_;
return os;
}
friend bool operator==(const DictClear &lhs, const DictClear &rhs) { return lhs.name_ == rhs.name_; }
};
using DictClearPtr = std::shared_ptr<DictClear>;
class DictHasKey : public MetaFuncGraph {
public:
explicit DictHasKey(const std::string &name) : MetaFuncGraph(name) {}
~DictHasKey() override = default;
MS_DECLARE_PARENT(DictHasKey, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const DictHasKey &dict_has_key) {
os << dict_has_key.name_;
return os;
}
friend bool operator==(const DictHasKey &lhs, const DictHasKey &rhs) { return lhs.name_ == rhs.name_; }
};
using DictHasKeyPtr = std::shared_ptr<DictHasKey>;
class DictUpdate : public MetaFuncGraph {
public:
explicit DictUpdate(const std::string &name) : MetaFuncGraph(name) {}
~DictUpdate() override = default;
MS_DECLARE_PARENT(DictUpdate, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const DictUpdate &dict_update) {
os << dict_update.name_;
return os;
}
friend bool operator==(const DictUpdate &lhs, const DictUpdate &rhs) { return lhs.name_ == rhs.name_; }
void AddNodeToLists(const AbstractBasePtr &arg, const FuncGraphPtr &ret, AnfNodePtrList *keys, AnfNodePtrList *values,
std::unordered_map<std::string, size_t> *hash_map);
};
using DictUpdatePtr = std::shared_ptr<DictUpdate>;
class DictFromKeys : public MetaFuncGraph {
public:
explicit DictFromKeys(const std::string &name) : MetaFuncGraph(name) {}
~DictFromKeys() override = default;
MS_DECLARE_PARENT(DictFromKeys, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const DictFromKeys &dict_from_keys) {
os << dict_from_keys.name_;
return os;
}
friend bool operator==(const DictFromKeys &lhs, const DictFromKeys &rhs) { return lhs.name_ == rhs.name_; }
abstract::AbstractBasePtrList ParseIterableObject(const abstract::AbstractBasePtr &arg_key);
};
using DictFromKeysPtr = std::shared_ptr<DictFromKeys>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DICT_OPERATION_H_

View File

@ -18,6 +18,7 @@
#include "frontend/operator/composite/composite.h"
#include "include/common/pybind_api/api_register.h"
#include "frontend/operator/composite/list_operation.h"
#include "frontend/operator/composite/dict_operation.h"
#include "frontend/operator/composite/map.h"
#include "frontend/operator/composite/unpack_call.h"
#include "frontend/operator/composite/vmap.h"
@ -103,6 +104,22 @@ void RegCompositeOpsGroup(const py::module *m) {
(void)py::class_<ListCount, MetaFuncGraph, std::shared_ptr<ListCount>>(*m, "ListCount_")
.def(py::init<const std::string &>());
// Reg DictClear
(void)py::class_<DictClear, MetaFuncGraph, std::shared_ptr<DictClear>>(*m, "DictClear_")
.def(py::init<const std::string &>());
// Reg DictHasKey
(void)py::class_<DictHasKey, MetaFuncGraph, std::shared_ptr<DictHasKey>>(*m, "DictHasKey_")
.def(py::init<const std::string &>());
// Reg DictUpdate
(void)py::class_<DictUpdate, MetaFuncGraph, std::shared_ptr<DictUpdate>>(*m, "DictUpdate_")
.def(py::init<const std::string &>());
// Reg DictFromKeys
(void)py::class_<DictFromKeys, MetaFuncGraph, std::shared_ptr<DictFromKeys>>(*m, "DictFromKeys_")
.def(py::init<const std::string &>());
// Reg MapPy
(void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"), py::arg("ops"))

View File

@ -146,15 +146,19 @@ BuiltInTypeMap &GetMethodMap() {
}},
{kObjectTypeDictionary,
{
{"__len__", prim::kPrimDictLen}, // P.dict_len
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
{"__ms_iter__", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
{"items", prim::kPrimDictItems}, // P.dict_items
{"__bool__", std::string("dict_bool")}, // C.dict_bool
{"get", std::string("dict_get")} // C.dict_get
{"__len__", prim::kPrimDictLen}, // P.dict_len
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
{"__ms_iter__", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
{"items", prim::kPrimDictItems}, // P.dict_items
{"__bool__", std::string("dict_bool")}, // C.dict_bool
{"get", std::string("dict_get")}, // C.dict_get
{"has_key", std::string("dict_haskey")}, // C.dict_haskey
{"clear", std::string("dict_clear")}, // C.dict_clear
{"update", std::string("dict_update")}, // C.dict_update
{"fromkeys", std::string("dict_fromkeys")} // C.dict_fromkeys
}},
{kObjectTypeTensorType,
{

View File

@ -190,15 +190,13 @@ AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitiveP
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict and a scalar whose value is a string.
const std::string op_name = primitive->name();
// dict[key] mean the size of args_spec_list is 2.
// dict.get('key', default=None) mean the size of args_spec_list is 3.
// dict.get('key', default_value=None) mean the size of args_spec_list is 2 too, the key will check in dict_get.
constexpr int subscript_args_size = 2;
constexpr int dict_get_arg_size = 3;
if (args_spec_list.size() != subscript_args_size && args_spec_list.size() != dict_get_arg_size) {
MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of input should be " << subscript_args_size << " or "
<< dict_get_arg_size << ", but got " << args_spec_list.size();
if (args_spec_list.size() != subscript_args_size) {
MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of input should be " << subscript_args_size
<< ", but got " << args_spec_list.size();
}
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
@ -214,13 +212,9 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
if (it == dict_elems.end()) {
// For dict[key], if key is not exist, will raise a KeyError exception.
if (args_spec_list.size() == subscript_args_size) {
MS_EXCEPTION(KeyError) << "The key " << key_str
<< " does not exist in the dict:" << args_spec_list[0]->ToString();
}
// For dict.get('key', default=None), if key is not exist, will return the default value.
constexpr int default_value_index = 2;
return args_spec_list[default_value_index];
// For dict.get('key', default=None), if key is not exist, will return the default value during dict_get.
MS_EXCEPTION(KeyError) << "The key " << key_value->ToString()
<< " does not exist in the dict:" << args_spec_list[0]->BuildValue()->ToString();
}
return it->second;
}

View File

@ -93,7 +93,7 @@ SYNTAX_UNSUPPORTED_NAMESPACE = 4 # Unsupported namespace
# Process expr statement white list
# Add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
parse_expr_statement_white_list = (
"append", "insert", "clear", "reverse", "extend",
"append", "insert", "clear", "reverse", "extend", "update",
)
_builtin_function_or_method_type = type(abs)

View File

@ -30,6 +30,8 @@ from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like, repeat_elements
from ...ops.composite.base import _append, _insert, _pop, _list_clear, _reverse, \
_count, _extend, _dict_clear, _haskey, _update, _fromkeys
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations.math_ops import Median
@ -2823,7 +2825,29 @@ def list_count(self_, value):
def dict_get(self_, key_index, default_value=None):
"""Get value by key from dict"""
return F.dict_getitem(self_, key_index, default_value)
if not _haskey(self_, key_index):
return default_value
return F.dict_getitem(self_, key_index)
def dict_clear(self_):
"""Clear the dict"""
return _dict_clear(self_)
def dict_haskey(self_, key_index):
"""Check if key is in dict"""
return _haskey(self_, key_index)
def dict_update(self_, dict_obj):
"""Update the dict"""
return _update(self_, dict_obj)
def dict_fromkeys(self_, seq, value=None):
"""Check if key is in dict"""
return _fromkeys(self_, seq, value)
#################

View File

@ -27,7 +27,8 @@ from mindspore import log as logger
from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
ListClear_, ListReverse_, ListExtend_, ListCount_
ListClear_, ListReverse_, ListExtend_, ListCount_, DictClear_, DictHasKey_, DictUpdate_, \
DictFromKeys_
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function, _pynative_executor, _wrap_func
from mindspore.ops.primitive import Primitive
@ -1086,6 +1087,82 @@ class _ListCount(ListCount_):
_count = _ListCount("count")
class _DictClear(DictClear_):
"""
A metafuncgraph class that clear the dict.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _DictClear."""
DictClear_.__init__(self, name)
def __call__(self, *args):
pass
_dict_clear = _DictClear("clear")
class _DictHasKey(DictHasKey_):
"""
A metafuncgraph class that Check if key is in dict.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _DictHasKey."""
DictHasKey_.__init__(self, name)
def __call__(self, *args):
pass
_haskey = _DictHasKey("has_key")
class _DictUpdate(DictUpdate_):
"""
A metafuncgraph class that append another dict to the end of the dict.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _DictUpdate."""
DictUpdate_.__init__(self, name)
def __call__(self, *args):
pass
_update = _DictUpdate("update")
class _DictFromKeys(DictFromKeys_):
"""
A metafuncgraph class that creates a new dict from the given sequence and value.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _DictFromKeys."""
DictFromKeys_.__init__(self, name)
def __call__(self, *args):
pass
_fromkeys = _DictFromKeys("fromkeys")
class _Tail(Tail_):
"""
A metafuncgraph class that generates tail elements of the tuple.

View File

@ -0,0 +1,64 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dict_clear """
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
def test_dict_clear_1():
"""
Feature: dict clear.
Description: support dict clear.
Expectation: No exception.
"""
@ms_function
def dict_net_1():
x = {'a': 1, 'b': 2}
x.clear()
return x
out = dict_net_1()
assert dict(out) == {}
def test_dict_clear_2():
"""
Feature: dict clear.
Description: support dict clear.
Expectation: No exception.
"""
@ms_function
def dict_net_2():
x = {'a': [1, 2, 'aa'], 'b': 2, 'c': Tensor(1)}
x.clear()
return x
out = dict_net_2()
assert dict(out) == {}
def test_dict_clear_3():
"""
Feature: dict clear.
Description: support dict clear.
Expectation: No exception.
"""
@ms_function
def dict_net_3():
x = {}
x.clear()
return x
out = dict_net_3()
assert dict(out) == {}

View File

@ -0,0 +1,217 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dict_fromkeys """
import ast
import pytest
from mindspore import ms_function, context
context.set_context(mode=context.GRAPH_MODE)
def test_dict_fromkeys_1():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_1():
x = {'a': 1, 'b': 2}
y = ['1', '2', '3']
new_dict = x.fromkeys(y)
return str(new_dict)
out = dict_net_1()
assert ast.literal_eval(out) == {'1': None, '2': None, '3': None}
def test_dict_fromkeys_2():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_2():
x = {'a': 1, 'b': 2}
y = ('1', '2', '3')
new_dict = x.fromkeys(y)
return str(new_dict)
out = dict_net_2()
assert ast.literal_eval(out) == {'1': None, '2': None, '3': None}
def test_dict_fromkeys_3():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_3():
x = {'a': 1, 'b': 2}
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
new_dict = x.fromkeys(y.keys())
return str(new_dict)
out = dict_net_3()
assert ast.literal_eval(out) == {'a': None, 'b': None, 'c': None, 'd': None}
def test_dict_fromkeys_4():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_4():
x = {'a': 1, 'b': 2}
y = ['1', '2', "3"]
new_dict = x.fromkeys(y, 123)
return str(new_dict)
out = dict_net_4()
assert ast.literal_eval(out) == {'1': 123, '2': 123, '3': 123}
def test_dict_fromkeys_5():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_5():
x = {'a': 1, 'b': 2}
y = ('1', '2', '3')
new_dict = x.fromkeys(y, 123)
return str(new_dict)
out = dict_net_5()
assert ast.literal_eval(out) == {'1': 123, '2': 123, '3': 123}
def test_dict_fromkeys_6():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_6():
x = {'a': 1, 'b': 2}
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
new_dict = x.fromkeys(y.keys(), 123)
return str(new_dict)
out = dict_net_6()
assert ast.literal_eval(out) == {'a': 123, 'b': 123, 'c': 123, 'd': 123}
def test_dict_fromkeys_7():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_7():
x = {'a': 1, 'b': 2}
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
new_dict = x.fromkeys(y, 123)
return str(new_dict)
out = dict_net_7()
assert ast.literal_eval(out) == {'a': 123, 'b': 123, 'c': 123, 'd': 123}
def test_dict_fromkeys_8():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_8():
x = {'a': 1, 'b': 2}
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
new_dict = x.fromkeys(y)
return str(new_dict)
out = dict_net_8()
assert ast.literal_eval(out) == {'a': None, 'b': None, 'c': None, 'd': None}
def test_dict_fromkeys_9():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_9():
x = {'a': 1, 'b': 2}
y = "abcd"
new_dict = x.fromkeys(y)
return str(new_dict)
out = dict_net_9()
assert ast.literal_eval(out) == {'a': None, 'b': None, 'c': None, 'd': None}
def test_dict_fromkeys_10():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_10():
x = {'a': 1, 'b': 2}
y = "abcd"
new_dict = x.fromkeys(y, 111)
return str(new_dict)
out = dict_net_10()
assert ast.literal_eval(out) == {'a': 111, 'b': 111, 'c': 111, 'd': 111}
def test_dict_fromkeys_11():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_11():
x = {'a': 1, 'b': 2}
y = 123
new_dict = x.fromkeys(y, 111)
return str(new_dict)
with pytest.raises(RuntimeError):
out = dict_net_11()
print(out)
def test_dict_fromkeys_12():
"""
Feature: dict fromkeys.
Description: support dict fromkeys.
Expectation: No exception.
"""
@ms_function
def dict_net_12():
x = {'a': 1, 'b': 2}
y = ['b', 1, 'c']
new_dict = x.fromkeys(y, 111)
return str(new_dict)
with pytest.raises(RuntimeError):
out = dict_net_12()
print(out)

View File

@ -126,3 +126,21 @@ def test_dict_get_7():
return Tensor(the_value)
out = dict_net_7()
assert (out.asnumpy() == (3, 4)).all()
def test_dict_get_8():
"""
Feature: dict get.
Description: support dict get set default value.
Expectation: No exception.
"""
@ms_function
def dict_net_8(x, y, z):
dict_x = {"1": x, "2": y}
default_value = dict_x.get("3", z)
return default_value
input_x = Tensor(1)
input_y = Tensor(2)
input_z = Tensor(3)
out = dict_net_8(input_x, input_y, input_z)
assert out == 3

View File

@ -0,0 +1,79 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dict_has_key """
from mindspore import ms_function, context
context.set_context(mode=context.GRAPH_MODE)
def test_dict_haskey_1():
"""
Feature: dict has_key.
Description: support dict has_key.
Expectation: No exception.
"""
@ms_function
def dict_net_1():
x = {'a': 1, 'b': 2}
res = x.has_key('a')
return res
out = dict_net_1()
assert out is True
def test_dict_haskey_2():
"""
Feature: dict has_key.
Description: support dict has_key.
Expectation: No exception.
"""
@ms_function
def dict_net_2():
x = {'a': [2, 3, "123"], 'b': 2}
res = x.has_key('a')
return res
out = dict_net_2()
assert out is True
def test_dict_haskey_3():
"""
Feature: dict has_key.
Description: support dict has_key.
Expectation: No exception.
"""
@ms_function
def dict_net_3():
x = {'a': 1, 'b': 2}
res = x.has_key('c')
return res
out = dict_net_3()
assert out is False
def test_dict_haskey_4():
"""
Feature: dict has_key.
Description: support dict has_key.
Expectation: No exception.
"""
@ms_function
def dict_net_4():
x = {"a": 1, "b": 2, "cd": 3, "c": 4}
res = x.has_key('c')
return res
out = dict_net_4()
assert out is True

View File

@ -0,0 +1,85 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dict_update """
import ast
from mindspore import ms_function, context
context.set_context(mode=context.GRAPH_MODE)
def test_dict_update_1():
"""
Feature: dict update.
Description: support dict update.
Expectation: No exception.
"""
@ms_function
def dict_net_1():
x = {'a': 1, 'b': 2}
y = {'c': 3, 'b': 4}
x.update(y)
return str(x)
out = dict_net_1()
assert ast.literal_eval(out) == {'a': 1, 'b': 4, 'c': 3}
def test_dict_update_2():
"""
Feature: dict update.
Description: support dict update.
Expectation: No exception.
"""
@ms_function
def dict_net_2():
x = {'a': 1, 'b': 2, 'aa': 11, 'bb': 22}
y = {'dd': {'ab': 12}, 'c': 3, 'b': "aaaa", 'ddd': [1, 2, 3]}
x.update(y)
return str(x)
out = dict_net_2()
assert ast.literal_eval(out) == {'a': 1, 'b': 'aaaa', 'aa': 11, 'bb': 22, 'dd': {'ab': 12},
'c': 3, 'ddd': [1, 2, 3]}
def test_dict_update_3():
"""
Feature: dict update.
Description: support dict update.
Expectation: No exception.
"""
@ms_function
def dict_net_3():
x = {'a': 1, 'b': 2}
y = {'c': 3}
x.update(y)
return str(x)
out = dict_net_3()
assert ast.literal_eval(out) == {'a': 1, 'b': 2, 'c': 3}
def test_dict_update_4():
"""
Feature: dict update.
Description: support dict update.
Expectation: No exception.
"""
@ms_function
def dict_net_4():
x = {'a': ["aa", "bb"], 'b': 2}
y = {'c': 3, "a": {"sub": "test"}}
x.update(y)
return str(x)
out = dict_net_4()
assert ast.literal_eval(out) == {'a': {"sub": "test"}, 'b': 2, 'c': 3}