forked from mindspore-Ecosystem/mindspore
!4060 fix do concat in while loop specialize error
Merge pull request !4060 from fary86/fix_while_concat_specialize_error
This commit is contained in:
commit
470328eeaf
|
@ -144,6 +144,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
|
|||
MS_EXCEPTION_IF_NULL(arg);
|
||||
return arg->Broaden();
|
||||
});
|
||||
if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
|
||||
MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
|
||||
<< " does not equal to number of original buffer arguments "
|
||||
<< func_graph_->joined_shapes_.size();
|
||||
}
|
||||
for (size_t i = 0; i < broaded_list.size(); ++i) {
|
||||
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
|
||||
}
|
||||
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
|
||||
<< ", broaded: " << mindspore::ToString(broaded_list);
|
||||
return broaded_list;
|
||||
|
@ -171,6 +179,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
|
||||
if (!(joined_args_spec_list == args_spec_list)) {
|
||||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
func_graph_->joined_shapes_.clear();
|
||||
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
|
||||
std::back_inserter(func_graph_->joined_shapes_),
|
||||
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
|
||||
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
||||
}
|
||||
return joined_args_spec_list;
|
||||
|
@ -185,6 +197,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
if (!(joined_args_spec_list == args_spec_list)) {
|
||||
trace_.push_back(joined_args_spec_list);
|
||||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
func_graph_->joined_shapes_.clear();
|
||||
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
|
||||
std::back_inserter(func_graph_->joined_shapes_),
|
||||
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
|
||||
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
|
||||
|
|
|
@ -332,6 +332,7 @@ class FuncGraph : public FuncGraphBase {
|
|||
std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
|
||||
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
std::vector<BaseShapePtr> joined_shapes_;
|
||||
std::unordered_map<std::string, FuncGraphTransform> transforms_;
|
||||
// parameter default value
|
||||
std::map<std::string, AnfNodePtr> parameter_default_value_;
|
||||
|
|
|
@ -220,6 +220,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
|
|||
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
|
||||
*target_func_graph = std::make_shared<FuncGraph>();
|
||||
(*target_func_graph)->set_attrs(func_graph->attrs());
|
||||
(*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_;
|
||||
(*target_func_graph)->set_transforms(func_graph->transforms());
|
||||
(*target_func_graph)->set_has_vararg(func_graph->has_vararg());
|
||||
(*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
|
||||
|
|
|
@ -645,3 +645,27 @@ def test_mixed_precision_cast():
|
|||
x = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
z = F.mixed_precision_cast(mstype.float16, x)
|
||||
assert z.dtype == mstype.float16
|
||||
|
||||
|
||||
def test_while_concat():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, data):
|
||||
super(Net, self).__init__()
|
||||
self.start = Tensor(0, dtype=mstype.int32)
|
||||
self.end = Tensor(2, dtype=mstype.int32)
|
||||
self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
|
||||
self.concat = P.Concat()
|
||||
|
||||
def construct(self, inputs):
|
||||
idx = self.start
|
||||
end = self.end
|
||||
out = self.out
|
||||
while idx < end:
|
||||
xi = inputs[idx, :, :]
|
||||
out = self.concat((out, xi))
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
|
||||
net = Net(x)
|
||||
net(x)
|
||||
|
|
Loading…
Reference in New Issue