forked from mindspore-Ecosystem/mindspore
!49656 Eliminate dict_getitem and mutable value dict
Merge pull request !49656 from YuJianfeng/dict_grad
This commit is contained in:
commit
64e12b46e2
|
@ -105,6 +105,7 @@ class DictGetitemConstEliminator : public AnfVisitor {
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
Reset();
|
Reset();
|
||||||
AnfVisitor::Match(prim::kPrimDictGetItem, {IsVNode, IsVNode})(node);
|
AnfVisitor::Match(prim::kPrimDictGetItem, {IsVNode, IsVNode})(node);
|
||||||
|
AnfVisitor::Match(prim::kPrimDictGetItem, {IsCNode, IsVNode})(node);
|
||||||
|
|
||||||
if (real_value_ != nullptr) {
|
if (real_value_ != nullptr) {
|
||||||
auto out = NewValueNode(real_value_);
|
auto out = NewValueNode(real_value_);
|
||||||
|
@ -130,6 +131,16 @@ class DictGetitemConstEliminator : public AnfVisitor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Visit(const CNodePtr &cnode) override {
|
||||||
|
if (!IsPrimitiveCNode(cnode, prim::kPrimMutable)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto dict_input = GetValueNode<ValueDictionaryPtr>(cnode->input(1));
|
||||||
|
if (dict_input != nullptr) {
|
||||||
|
dict_ = dict_input;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
real_value_ = nullptr;
|
real_value_ = nullptr;
|
||||||
dict_ = nullptr;
|
dict_ = nullptr;
|
||||||
|
|
|
@ -412,9 +412,10 @@ def test_grad_const_list_and_tuple_tensor_to_mutable():
|
||||||
assert compare(output, expect)
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="randomly failed.")
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
def test_grad_const_dict_tensor_to_mutable():
|
def test_grad_const_dict_tensor_to_mutable():
|
||||||
"""
|
"""
|
||||||
|
@ -477,8 +478,10 @@ def test_grad_const_dict_tensor_to_mutable():
|
||||||
assert compare(output['b'], expect[1])
|
assert compare(output['b'], expect[1])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level1
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
def test_grad_const_dict_tensor_arg_to_mutable():
|
def test_grad_const_dict_tensor_arg_to_mutable():
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue