!2720 fix assign used in while loop

Merge pull request !2720 from xychow/fix-assign-in-while
This commit is contained in:
mindspore-ci-bot 2020-07-01 09:16:13 +08:00 committed by Gitee
commit ea475637a1
5 changed files with 70 additions and 22 deletions

View File

@ -314,7 +314,7 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
auto weight_name = weight_node->cast<ParameterPtr>()->name(); auto weight_name = weight_node->cast<ParameterPtr>()->name();
// find the fakequant from input // find the fakequant from input
int count = 0; int count = 0;
int max_depth = 5; const int max_depth = 5;
while (!is_quant_cnode(x)) { while (!is_quant_cnode(x)) {
if (count >= max_depth) { if (count >= max_depth) {
break; break;

View File

@ -451,6 +451,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
if (other_tensor == nullptr) { if (other_tensor == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
} }
if (*this == *other) {
if (sparse_grad() == other->sparse_grad()) {
return shared_from_base<AbstractBase>();
}
}
auto element = element_->Join(other_tensor->element_); auto element = element_->Join(other_tensor->element_);
auto shape = ShapeJoin(this->shape(), other_tensor->shape()); auto shape = ShapeJoin(this->shape(), other_tensor->shape());
auto ret = std::make_shared<AbstractTensor>(element, shape); auto ret = std::make_shared<AbstractTensor>(element, shape);
@ -830,6 +835,21 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
return false; return false;
} }
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {
return shared_from_base<AbstractBase>();
}
auto ref_key = ref_key_->Join(other_ref->ref_key_);
auto ref = ref_->Join(other_ref->ref());
auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
}
std::string AbstractRef::ToString() const { std::string AbstractRef::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
buffer << type_name() << "(" buffer << type_name() << "("

View File

@ -578,6 +578,7 @@ class AbstractRef : public AbstractBase {
AbstractBasePtr Broaden() const override { AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
} }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override { std::size_t hash() const override {
return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
} }

View File

@ -166,6 +166,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if (!(joined_args_spec_list == args_spec_list)) { if (!(joined_args_spec_list == args_spec_list)) {
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
} }
return joined_args_spec_list; return joined_args_spec_list;
} }
@ -179,6 +180,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
if (!(joined_args_spec_list == args_spec_list)) { if (!(joined_args_spec_list == args_spec_list)) {
trace_.push_back(joined_args_spec_list); trace_.push_back(joined_args_spec_list);
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
} }
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
return joined_args_spec_list; return joined_args_spec_list;

View File

@ -16,29 +16,54 @@
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor, context
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
class Net(nn.Cell): import mindspore.ops.operations as op
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
self.bn = nn.BatchNorm2d(64)
self.fc = nn.Dense(64, 10)
self.relu = nn.ReLU()
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.flatten(x)
out = self.fc(x)
return out
def test_net_infer(): def test_net_infer():
""" test_net_infer """ """ test_net_infer """
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
self.bn = nn.BatchNorm2d(64)
self.fc = nn.Dense(64, 10)
self.relu = nn.ReLU()
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.flatten(x)
out = self.fc(x)
return out
Tensor(np.random.randint(0, 255, [1, 3, 224, 224])) Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
Net() Net()
def test_assign_in_while():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self, input_shape):
super().__init__()
self.assign = op.Assign()
self.inputdata = Parameter(initializer(1, input_shape), name="global_step")
def construct(self, x, y, z):
out = z
while x < y:
inputdata = self.inputdata
x = x + 1
out = self.assign(inputdata, z)
return out
x = Tensor(np.array(1).astype(np.int32))
y = Tensor(np.array(3).astype(np.int32))
input_shape = (1024, 512)
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
net = Net(input_shape)
ret = net(x, y, z)
assert ret == z