!2720 fix assign used in while loop

Merge pull request !2720 from xychow/fix-assign-in-while
mindspore-ci-bot 2020-07-01 09:16:13 +08:00 committed by Gitee
commit ea475637a1
5 changed files with 70 additions and 22 deletions
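
In short, reading the diff below: type inference joins the abstract values of loop variables across while iterations, and a Parameter written by Assign shows up there as an AbstractRef. AbstractRef previously had no Join of its own, so this change overrides it to join ref_key, ref, and ref_origin component-wise. Alongside that, AbstractTensor::Join now short-circuits only when both the value and sparse_grad agree, two MS_LOG(DEBUG) lines make the loop-variant broadening decision visible, a small const-correctness cleanup rides along, and a regression test (test_assign_in_while) covers the reported pattern.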

View File

@@ -314,7 +314,7 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
   auto weight_name = weight_node->cast<ParameterPtr>()->name();
   // find the fakequant from input
   int count = 0;
-  int max_depth = 5;
+  const int max_depth = 5;
   while (!is_quant_cnode(x)) {
     if (count >= max_depth) {
       break;
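
(An unrelated tidy-up: max_depth in the fake-quant lookup loop above never changes, so it becomes const.)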

View File

@@ -451,6 +451,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   if (other_tensor == nullptr) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
   }
+  if (*this == *other) {
+    if (sparse_grad() == other->sparse_grad()) {
+      return shared_from_base<AbstractBase>();
+    }
+  }
   auto element = element_->Join(other_tensor->element_);
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
   auto ret = std::make_shared<AbstractTensor>(element, shape);
@@ -830,6 +835,21 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
   return false;
 }
+
+AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
+  auto other_ref = other->cast<AbstractRefPtr>();
+  if (other_ref == nullptr) {
+    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
+  }
+  if (*this == *other) {
+    return shared_from_base<AbstractBase>();
+  }
+  auto ref_key = ref_key_->Join(other_ref->ref_key_);
+  auto ref = ref_->Join(other_ref->ref());
+  auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
+  return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
+}
+
 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
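
Why the new Join matters for Assign in a while loop: when the evaluator joins the argument specs of successive loop iterations, a Parameter appears as an AbstractRef, and without the override above its components were never combined. A toy sketch of the component-wise join pattern, in plain Python with made-up names (not MindSpore internals):

from dataclasses import dataclass

@dataclass(frozen=True)
class Abstract:
    """Toy abstract value: the set of concrete values it may stand for."""
    values: frozenset

    def join(self, other):
        # Least upper bound: allow anything either side allows.
        return Abstract(self.values | other.values)

@dataclass(frozen=True)
class Ref:
    """Toy reference: like AbstractRef::Join above, join component-wise."""
    key: Abstract
    value: Abstract

    def join(self, other):
        if self == other:  # same short-circuit as the C++ code
            return self
        return Ref(self.key.join(other.key), self.value.join(other.value))

a = Ref(Abstract(frozenset({"step"})), Abstract(frozenset({1})))
b = Ref(Abstract(frozenset({"step"})), Abstract(frozenset({2})))
assert a.join(b).value.values == frozenset({1, 2})  # key unchanged, values merged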

View File

@@ -578,6 +578,7 @@ class AbstractRef : public AbstractBase {
   AbstractBasePtr Broaden() const override {
     return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
   }
+  AbstractBasePtr Join(const AbstractBasePtr &other) override;
   std::size_t hash() const override {
     return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
   }

View File

@@ -166,6 +166,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
       // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
       if (!(joined_args_spec_list == args_spec_list)) {
         func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+        MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
       }
       return joined_args_spec_list;
     }
@@ -179,6 +180,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
   if (!(joined_args_spec_list == args_spec_list)) {
     trace_.push_back(joined_args_spec_list);
     func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+    MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
   }
   MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
   return joined_args_spec_list;
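
Context for the two log lines added above: when the joined argument specs differ from the current ones there is a loop variant, so the function graph is flagged with IGNORE_VALUES and argument values are broadened, that is, treated as variables rather than constants, to avoid the wrong constant propagation the existing comment warns about. The new MS_LOG(DEBUG) statements only make that decision observable; they do not change behavior.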

View File

@@ -16,9 +16,13 @@
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import Tensor, context
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+import mindspore.ops.operations as op
-def test_net_infer():
-    """ test_net_infer """
-    class Net(nn.Cell):
-        """ Net definition """
+class Net(nn.Cell):
+    """ Net definition """
@@ -36,9 +40,30 @@ class Net(nn.Cell):
         x = self.flatten(x)
         out = self.fc(x)
         return out
+
+
+def test_net_infer():
+    """ test_net_infer """
+    Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
+    Net()
+
+
+def test_assign_in_while():
+    context.set_context(mode=context.GRAPH_MODE)
+
+    class Net(nn.Cell):
+        def __init__(self, input_shape):
+            super().__init__()
+            self.assign = op.Assign()
+            self.inputdata = Parameter(initializer(1, input_shape), name="global_step")
+
+        def construct(self, x, y, z):
+            out = z
+            while x < y:
+                inputdata = self.inputdata
+                x = x + 1
+                out = self.assign(inputdata, z)
+            return out
+
+    x = Tensor(np.array(1).astype(np.int32))
+    y = Tensor(np.array(3).astype(np.int32))
+    input_shape = (1024, 512)
+    z = Tensor(np.random.randn(*input_shape).astype(np.float32))
+    net = Net(input_shape)
+    ret = net(x, y, z)
+    assert ret == z
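
To run the new regression test on its own, something like the following should work; the test file's path is an assumption here, since this view does not show file names.

# Hypothetical invocation: the path below is assumed, not taken from the diff.
import pytest

pytest.main(["-q", "tests/ut/python/ir/test_net_infer.py::test_assign_in_while"])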