forked from OSSInnovation/mindspore
fix assign used in while
This commit is contained in:
parent
bc30576ac9
commit
d5255fe311
|
@ -314,7 +314,7 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
|
|||
auto weight_name = weight_node->cast<ParameterPtr>()->name();
|
||||
// find the fakequant from input
|
||||
int count = 0;
|
||||
int max_depth = 5;
|
||||
const int max_depth = 5;
|
||||
while (!is_quant_cnode(x)) {
|
||||
if (count >= max_depth) {
|
||||
break;
|
||||
|
|
|
@ -429,6 +429,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
|||
if (other_tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
}
|
||||
if (*this == *other) {
|
||||
if (sparse_grad() == other->sparse_grad()) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
}
|
||||
auto element = element_->Join(other_tensor->element_);
|
||||
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
|
||||
auto ret = std::make_shared<AbstractTensor>(element, shape);
|
||||
|
@ -812,6 +817,21 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
|
||||
auto other_ref = other->cast<AbstractRefPtr>();
|
||||
if (other_ref == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
}
|
||||
if (*this == *other) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
auto ref_key = ref_key_->Join(other_ref->ref_key_);
|
||||
auto ref = ref_->Join(other_ref->ref());
|
||||
auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
|
||||
|
||||
return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
|
||||
}
|
||||
|
||||
std::string AbstractRef::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << type_name() << "("
|
||||
|
|
|
@ -564,6 +564,7 @@ class AbstractRef : public AbstractBase {
|
|||
AbstractBasePtr Broaden() const override {
|
||||
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
|
||||
}
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
std::size_t hash() const override {
|
||||
return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
|
||||
}
|
||||
|
|
|
@ -166,6 +166,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
|
||||
if (!(joined_args_spec_list == args_spec_list)) {
|
||||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
||||
}
|
||||
return joined_args_spec_list;
|
||||
}
|
||||
|
@ -179,6 +180,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
if (!(joined_args_spec_list == args_spec_list)) {
|
||||
trace_.push_back(joined_args_spec_list);
|
||||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
|
||||
return joined_args_spec_list;
|
||||
|
|
|
@ -16,29 +16,54 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
||||
self.bn = nn.BatchNorm2d(64)
|
||||
self.fc = nn.Dense(64, 10)
|
||||
self.relu = nn.ReLU()
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.relu(x)
|
||||
x = self.flatten(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
def test_net_infer():
|
||||
""" test_net_infer """
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
||||
self.bn = nn.BatchNorm2d(64)
|
||||
self.fc = nn.Dense(64, 10)
|
||||
self.relu = nn.ReLU()
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.relu(x)
|
||||
x = self.flatten(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
|
||||
Net()
|
||||
|
||||
|
||||
def test_assign_in_while():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, input_shape):
|
||||
super().__init__()
|
||||
self.assign = op.Assign()
|
||||
self.inputdata = Parameter(initializer(1, input_shape), name="global_step")
|
||||
|
||||
def construct(self, x, y, z):
|
||||
out = z
|
||||
while x < y:
|
||||
inputdata = self.inputdata
|
||||
x = x + 1
|
||||
out = self.assign(inputdata, z)
|
||||
return out
|
||||
|
||||
x = Tensor(np.array(1).astype(np.int32))
|
||||
y = Tensor(np.array(3).astype(np.int32))
|
||||
input_shape = (1024, 512)
|
||||
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
|
||||
net = Net(input_shape)
|
||||
ret = net(x, y, z)
|
||||
assert ret == z
|
||||
|
|
Loading…
Reference in New Issue