!40351 Modifiy load infer.
Merge pull request !40351 from Margaret_wangrui/vmap_auto_monad
This commit is contained in:
commit
f410f67e67
|
@ -222,7 +222,8 @@ class OrderEnforcer {
|
|||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto &input = inputs[i];
|
||||
// Skip non-ref input and update_state.
|
||||
if (!IsRef(input) || input == update_state) {
|
||||
// Skip Load(param, umonad) --> Ref
|
||||
if (!IsRef(input) || input == update_state || IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
continue;
|
||||
}
|
||||
// The input is a ref (of parameter), find load nodes for it.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -437,11 +437,6 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
|||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: Ref/Tensor, universal
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
||||
auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(args_spec_list[0]);
|
||||
if (ref_abs != nullptr) {
|
||||
// Return tensor value if input is Ref.
|
||||
return ref_abs->CloneAsTensor();
|
||||
}
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
|
|
|
@ -228,3 +228,32 @@ def test_load_eliminate():
|
|||
net = Net()
|
||||
out = net(x)
|
||||
assert out == 5
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_tuple_assign():
|
||||
"""
|
||||
Feature: Auto monad feature.
|
||||
Description: Parameter tuple assign.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.assign = P.Assign()
|
||||
self.param1 = Parameter(Tensor(0), name="param1")
|
||||
self.param2 = Parameter(Tensor(0), name="param2")
|
||||
|
||||
def construct(self, x):
|
||||
params = (self.param1, self.param2)
|
||||
self.assign(params[0], x)
|
||||
return params[0], params[1]
|
||||
|
||||
x = Tensor(2)
|
||||
net = Net()
|
||||
out = net(x)
|
||||
assert out[0] == 2
|
||||
assert out[1] == 0
|
||||
|
|
Loading…
Reference in New Issue