!49656 Eliminate dict_getitem and mutable value dict

Merge pull request !49656 from YuJianfeng/dict_grad
This commit is contained in:
i-robot 2023-03-03 09:33:44 +00:00 committed by Gitee
commit 64e12b46e2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 18 additions and 4 deletions

View File

@ -105,6 +105,7 @@ class DictGetitemConstEliminator : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimDictGetItem, {IsVNode, IsVNode})(node); AnfVisitor::Match(prim::kPrimDictGetItem, {IsVNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimDictGetItem, {IsCNode, IsVNode})(node);
if (real_value_ != nullptr) { if (real_value_ != nullptr) {
auto out = NewValueNode(real_value_); auto out = NewValueNode(real_value_);
@ -130,6 +131,16 @@ class DictGetitemConstEliminator : public AnfVisitor {
} }
} }
void Visit(const CNodePtr &cnode) override {
if (!IsPrimitiveCNode(cnode, prim::kPrimMutable)) {
return;
}
auto dict_input = GetValueNode<ValueDictionaryPtr>(cnode->input(1));
if (dict_input != nullptr) {
dict_ = dict_input;
}
}
void Reset() { void Reset() {
real_value_ = nullptr; real_value_ = nullptr;
dict_ = nullptr; dict_ = nullptr;

View File

@ -412,9 +412,10 @@ def test_grad_const_list_and_tuple_tensor_to_mutable():
assert compare(output, expect) assert compare(output, expect)
@pytest.mark.skip(reason="randomly failed.")
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_grad_const_dict_tensor_to_mutable(): def test_grad_const_dict_tensor_to_mutable():
""" """
@ -477,8 +478,10 @@ def test_grad_const_dict_tensor_to_mutable():
assert compare(output['b'], expect[1]) assert compare(output['b'], expect[1])
@pytest.mark.level1 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_grad_const_dict_tensor_arg_to_mutable(): def test_grad_const_dict_tensor_arg_to_mutable():
""" """