Fix specialize error when concat is done in a while loop

fary86 2020-08-06 19:41:11 +08:00
parent ef292bb919
commit 7602054acd
4 changed files with 42 additions and 0 deletions


@@ -144,6 +144,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
                           MS_EXCEPTION_IF_NULL(arg);
                           return arg->Broaden();
                         });
    if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
      MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
                               << " does not equal the number of original buffer arguments "
                               << func_graph_->joined_shapes_.size();
    }
    for (size_t i = 0; i < broaded_list.size(); ++i) {
      broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
    }
    MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
                  << ", broaded: " << mindspore::ToString(broaded_list);
    return broaded_list;
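In this first hunk, once a graph carries the IGNORE_VALUES flag, NormalizeArgs broadens every argument (dropping the tracked value) and then reapplies the shape recorded at join time, failing fast if the counts disagree. A minimal standalone sketch of that broaden-then-restore pattern, using simplified stand-in types rather than the real AbstractBasePtr/BaseShapePtr classes:

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for AbstractBasePtr / BaseShapePtr, for illustration only.
struct Shape {
  std::vector<int> dims;
};
using ShapePtr = std::shared_ptr<Shape>;

struct Abstract {
  ShapePtr shape;
  bool has_value = true;  // whether a concrete value is still being tracked

  // Mirrors arg->Broaden(): drop the tracked value, keep the shape.
  Abstract Broaden() const {
    Abstract broadened = *this;
    broadened.has_value = false;
    return broadened;
  }
};

// Broaden every argument, then restore the shapes recorded at join time.
std::vector<Abstract> NormalizeArgs(const std::vector<Abstract> &args,
                                    const std::vector<ShapePtr> &joined_shapes) {
  std::vector<Abstract> broadened;
  broadened.reserve(args.size());
  for (const auto &arg : args) {
    broadened.push_back(arg.Broaden());
  }
  // The new size check: every argument must have a recorded joined shape.
  if (joined_shapes.size() != broadened.size()) {
    throw std::invalid_argument("argument count does not match the recorded joined shapes");
  }
  for (std::size_t i = 0; i < broadened.size(); ++i) {
    broadened[i].shape = joined_shapes[i];  // reapply the joined shape
  }
  return broadened;
}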
@@ -171,6 +179,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
    // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation.
    if (!(joined_args_spec_list == args_spec_list)) {
      func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
      func_graph_->joined_shapes_.clear();
      std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
                     std::back_inserter(func_graph_->joined_shapes_),
                     [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
      MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
    }
    return joined_args_spec_list;
@@ -185,6 +197,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
    if (!(joined_args_spec_list == args_spec_list)) {
      trace_.push_back(joined_args_spec_list);
      func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
      func_graph_->joined_shapes_.clear();
      std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
                     std::back_inserter(func_graph_->joined_shapes_),
                     [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
      MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
    }
    MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);


@@ -332,6 +332,7 @@ class FuncGraph : public FuncGraphBase {
  std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
  std::unordered_map<std::string, ValuePtr> attrs_;
  std::vector<BaseShapePtr> joined_shapes_;
  std::unordered_map<std::string, FuncGraphTransform> transforms_;
  // parameter default value
  std::map<std::string, AnfNodePtr> parameter_default_value_;


@@ -220,6 +220,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
  TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
  *target_func_graph = std::make_shared<FuncGraph>();
  (*target_func_graph)->set_attrs(func_graph->attrs());
  (*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_;
  (*target_func_graph)->set_transforms(func_graph->transforms());
  (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
  (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
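This one-line cloner change keeps the new invariant intact: specialization clones graphs, and since set_attrs copies the attribute map that set_flag writes to, a clone would inherit FUNC_GRAPH_FLAG_IGNORE_VALUES while leaving joined_shapes_ empty, tripping the size check added in NormalizeArgs. Copying the vector alongside the attributes keeps the flag and its shape cache consistent.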


@@ -645,3 +645,27 @@ def test_mixed_precision_cast():
    x = Tensor(np.ones([2, 3], dtype=np.float32))
    z = F.mixed_precision_cast(mstype.float16, x)
    assert z.dtype == mstype.float16


def test_while_concat():
    class Net(nn.Cell):
        def __init__(self, data):
            super(Net, self).__init__()
            self.start = Tensor(0, dtype=mstype.int32)
            self.end = Tensor(2, dtype=mstype.int32)
            self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
            self.concat = P.Concat()

        def construct(self, inputs):
            idx = self.start
            end = self.end
            out = self.out
            while idx < end:
                xi = inputs[idx, :, :]
                out = self.concat((out, xi))
                idx = idx + 1
            return out

    x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
    net = Net(x)
    net(x)
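The regression test exercises exactly this path: out starts as [2, 3] zeros and grows by two rows on every Concat, so the while-body graph receives differently shaped arguments on each iteration. That loop variance triggers the join-and-broaden path patched above; before this commit the broadened arguments lost the joined shapes, producing the specialize error named in the title.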