From 7602054acd61ae1a99a257ca9822c7fd13251c61 Mon Sep 17 00:00:00 2001
From: fary86
Date: Thu, 6 Aug 2020 19:41:11 +0800
Subject: [PATCH] Fix do concat in while loop specialize error

Record the joined shape of each argument when a func graph is marked with
the IGNORE_VALUES flag, and re-apply those shapes to the broadened
arguments in FuncGraphEvaluator::NormalizeArgs, so that a loop-variant
shape (such as the output of Concat growing inside a while loop) survives
specialization. The joined_shapes_ buffer is also copied by the func graph
cloner, and a UT covering concat inside a while loop is added.

---
 .../pipeline/jit/static_analysis/evaluator.cc | 16 +++++++++++++
 mindspore/core/ir/func_graph.h                |  1 +
 mindspore/core/ir/func_graph_cloner.cc        |  1 +
 tests/ut/python/ops/test_control_ops.py       | 24 +++++++++++++++++++
 4 files changed, 42 insertions(+)

diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
index b07bb270f19..424a057bc36 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
@@ -144,6 +144,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
                            MS_EXCEPTION_IF_NULL(arg);
                            return arg->Broaden();
                          });
+    if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
+      MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
+                               << " does not equal to number of original buffer arguments "
+                               << func_graph_->joined_shapes_.size();
+    }
+    for (size_t i = 0; i < broaded_list.size(); ++i) {
+      broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
+    }
     MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
                   << ", broaded: " << mindspore::ToString(broaded_list);
     return broaded_list;
@@ -171,6 +179,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
       // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
       if (!(joined_args_spec_list == args_spec_list)) {
         func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+        func_graph_->joined_shapes_.clear();
+        std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
+                       std::back_inserter(func_graph_->joined_shapes_),
+                       [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
         MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
       }
       return joined_args_spec_list;
@@ -185,6 +197,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
     if (!(joined_args_spec_list == args_spec_list)) {
       trace_.push_back(joined_args_spec_list);
       func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+      func_graph_->joined_shapes_.clear();
+      std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
+                     std::back_inserter(func_graph_->joined_shapes_),
+                     [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
       MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
     }
     MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h
index ba1da220a74..3ce74cfb5ba 100644
--- a/mindspore/core/ir/func_graph.h
+++ b/mindspore/core/ir/func_graph.h
@@ -332,6 +332,7 @@ class FuncGraph : public FuncGraphBase {
   std::unordered_map &make_ref_params() { return make_ref_params_; }

   std::unordered_map<std::string, ValuePtr> attrs_;
+  std::vector<abstract::BaseShapePtr> joined_shapes_;
   std::unordered_map<std::string, FuncGraphTransform> transforms_;
   // parameter default value
   std::map<std::string, AnfNodePtr> parameter_default_value_;
diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc
index bb8ed7bc4c4..0e6b73201bf 100644
--- a/mindspore/core/ir/func_graph_cloner.cc
+++ b/mindspore/core/ir/func_graph_cloner.cc
@@ -220,6 +220,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
   TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
   *target_func_graph = std::make_shared<FuncGraph>();
   (*target_func_graph)->set_attrs(func_graph->attrs());
+  (*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_;
   (*target_func_graph)->set_transforms(func_graph->transforms());
   (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
   (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index 753c4856a31..369fe5f9b1d 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -645,3 +645,27 @@ def test_mixed_precision_cast():
     x = Tensor(np.ones([2, 3], dtype=np.float32))
     z = F.mixed_precision_cast(mstype.float16, x)
     assert z.dtype == mstype.float16
+
+
+def test_while_concat():
+    class Net(nn.Cell):
+        def __init__(self, data):
+            super(Net, self).__init__()
+            self.start = Tensor(0, dtype=mstype.int32)
+            self.end = Tensor(2, dtype=mstype.int32)
+            self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
+            self.concat = P.Concat()
+
+        def construct(self, inputs):
+            idx = self.start
+            end = self.end
+            out = self.out
+            while idx < end:
+                xi = inputs[idx, :, :]
+                out = self.concat((out, xi))
+                idx = idx + 1
+            return out
+
+    x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
+    net = Net(x)
+    net(x)
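
Note: below is a standalone sketch of the pattern the new UT exercises, for
trying it outside the UT suite. It assumes a MindSpore build containing this
fix and graph mode; the imports, which live at the top of
test_control_ops.py and are not part of this hunk, are spelled out
explicitly, and the class name WhileConcatNet is illustrative.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.common import dtype as mstype
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE)

    class WhileConcatNet(nn.Cell):  # illustrative name, mirrors the UT's Net
        def __init__(self):
            super(WhileConcatNet, self).__init__()
            self.start = Tensor(0, dtype=mstype.int32)
            self.end = Tensor(2, dtype=mstype.int32)
            self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
            self.concat = P.Concat()  # concatenates along axis 0 by default

        def construct(self, inputs):
            idx = self.start
            out = self.out
            # The shape of `out` grows by one [2, 3] block per iteration, so
            # it is a loop variant; before this fix, specializing the while
            # body lost the joined shape when the arguments were broadened.
            while idx < self.end:
                out = self.concat((out, inputs[idx, :, :]))
                idx = idx + 1
            return out

    x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
    out = WhileConcatNet()(x)  # compiles without the specialize error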