!40351 Modifiy load infer.

Merge pull request !40351 from Margaret_wangrui/vmap_auto_monad
This commit is contained in:
i-robot 2022-08-16 02:07:09 +00:00 committed by Gitee
commit f410f67e67
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 32 additions and 7 deletions

View File

@ -222,7 +222,8 @@ class OrderEnforcer {
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
auto &input = inputs[i]; auto &input = inputs[i];
// Skip non-ref input and update_state. // Skip non-ref input and update_state.
if (!IsRef(input) || input == update_state) { // Skip Load(param, umonad) --> Ref
if (!IsRef(input) || input == update_state || IsPrimitiveCNode(input, prim::kPrimLoad)) {
continue; continue;
} }
// The input is a ref (of parameter), find load nodes for it. // The input is a ref (of parameter), find load nodes for it.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2021 Huawei Technologies Co., Ltd * Copyright 2019-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -437,11 +437,6 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref/Tensor, universal // Inputs: Ref/Tensor, universal
CheckArgsSize(primitive->name(), args_spec_list, 2); CheckArgsSize(primitive->name(), args_spec_list, 2);
auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(args_spec_list[0]);
if (ref_abs != nullptr) {
// Return tensor value if input is Ref.
return ref_abs->CloneAsTensor();
}
return args_spec_list[0]->Broaden(); return args_spec_list[0]->Broaden();
} }

View File

@ -228,3 +228,32 @@ def test_load_eliminate():
net = Net() net = Net()
out = net(x) out = net(x)
assert out == 5 assert out == 5
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_tuple_assign():
"""
Feature: Auto monad feature.
Description: Parameter tuple assign.
Expectation: No exception.
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.assign = P.Assign()
self.param1 = Parameter(Tensor(0), name="param1")
self.param2 = Parameter(Tensor(0), name="param2")
def construct(self, x):
params = (self.param1, self.param2)
self.assign(params[0], x)
return params[0], params[1]
x = Tensor(2)
net = Net()
out = net(x)
assert out[0] == 2
assert out[1] == 0