!40351 Modifiy load infer.
Merge pull request !40351 from Margaret_wangrui/vmap_auto_monad
This commit is contained in:
commit
f410f67e67
|
@ -222,7 +222,8 @@ class OrderEnforcer {
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
auto &input = inputs[i];
|
auto &input = inputs[i];
|
||||||
// Skip non-ref input and update_state.
|
// Skip non-ref input and update_state.
|
||||||
if (!IsRef(input) || input == update_state) {
|
// Skip Load(param, umonad) --> Ref
|
||||||
|
if (!IsRef(input) || input == update_state || IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// The input is a ref (of parameter), find load nodes for it.
|
// The input is a ref (of parameter), find load nodes for it.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -437,11 +437,6 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
// Inputs: Ref/Tensor, universal
|
// Inputs: Ref/Tensor, universal
|
||||||
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
||||||
auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(args_spec_list[0]);
|
|
||||||
if (ref_abs != nullptr) {
|
|
||||||
// Return tensor value if input is Ref.
|
|
||||||
return ref_abs->CloneAsTensor();
|
|
||||||
}
|
|
||||||
return args_spec_list[0]->Broaden();
|
return args_spec_list[0]->Broaden();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -228,3 +228,32 @@ def test_load_eliminate():
|
||||||
net = Net()
|
net = Net()
|
||||||
out = net(x)
|
out = net(x)
|
||||||
assert out == 5
|
assert out == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_parameter_tuple_assign():
|
||||||
|
"""
|
||||||
|
Feature: Auto monad feature.
|
||||||
|
Description: Parameter tuple assign.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.assign = P.Assign()
|
||||||
|
self.param1 = Parameter(Tensor(0), name="param1")
|
||||||
|
self.param2 = Parameter(Tensor(0), name="param2")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
params = (self.param1, self.param2)
|
||||||
|
self.assign(params[0], x)
|
||||||
|
return params[0], params[1]
|
||||||
|
|
||||||
|
x = Tensor(2)
|
||||||
|
net = Net()
|
||||||
|
out = net(x)
|
||||||
|
assert out[0] == 2
|
||||||
|
assert out[1] == 0
|
||||||
|
|
Loading…
Reference in New Issue