Fix tuple in tuple in call node.

This commit is contained in:
gaoyong10 2022-01-14 09:56:45 +08:00
parent ef5e5c2aae
commit 2045fc3be2
2 changed files with 18 additions and 14 deletions

View File

@ -51,6 +51,13 @@ void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
ControlActor::FetchInput(context);
size_t index = GetIndex(context);
if (!output_partial_arrows_.empty()) {
if (index + kSwitchCondPos >= input_partials_.size()) {
string error_info = "Given index " + std::to_string(index) +
" out of range. Please make sure the value of index in [" +
std::to_string(1 - SizeToInt(input_partials_.size())) + ", " +
std::to_string(input_partials_.size() - 1) + "), and the type is int32.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
MS_EXCEPTION_IF_NULL(input_partials_[index + kSwitchCondPos]);
auto func_graph = input_partials_[index + kSwitchCondPos]->func_graph_;
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -512,21 +512,18 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
}
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
if (!abstract->isa<abstract::AbstractTuple>()) {
for (size_t i = 0; i < output_num; ++i) {
(void)results.emplace_back(real_node, i);
}
return results;
for (size_t i = 0; i < output_num; ++i) {
(void)results.emplace_back(real_node, i);
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
size_t index = 0;
for (const auto &sub_abstract : sub_abstracts) {
MS_EXCEPTION_IF_NULL(sub_abstract);
if (!sub_abstract->isa<abstract::AbstractMonad>()) {
(void)results.emplace_back(real_node, index++);
if (abstract->isa<abstract::AbstractTuple>()) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
for (const auto &sub_abstract : sub_abstracts) {
MS_EXCEPTION_IF_NULL(sub_abstract);
if (sub_abstract->isa<abstract::AbstractMonad>()) {
(void)results.pop_back();
}
}
}
return results;