Fix VectorRefToPyData

This commit is contained in:
huangbingjian 2022-12-23 11:23:58 +08:00
parent cc14cbd5be
commit b698409a03
2 changed files with 62 additions and 15 deletions

View File

@ -452,28 +452,42 @@ py::object VectorToPyData(const Any &value) {
}
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
// Current VectorRef reflects a COOTensor type
if (abs != nullptr && abs->isa<abstract::AbstractCSRTensor>()) {
return MakeCSRTensor(value_list);
}
if (abs != nullptr && abs->isa<abstract::AbstractCOOTensor>()) {
return MakeCOOTensor(value_list);
}
py::object ret;
size_t value_size = value_list.size();
auto ref_tuple = py::tuple(value_size);
auto seq_abs = CheckAbstractElementsSize<abstract::AbstractSequencePtr>(abs, value_size);
if (seq_abs == nullptr) {
if (abs == nullptr) {
for (size_t i = 0; i < value_size; i++) {
ref_tuple[i] = BaseRefToPyData(value_list[i]);
}
} else {
for (size_t i = 0; i < value_size; i++) {
ref_tuple[i] = BaseRefToPyData(value_list[i], seq_abs->elements()[i]);
}
ret = ref_tuple;
return ret;
}
return ref_tuple;
// Current VectorRef reflects a COOTensor type
if (abs->isa<abstract::AbstractCSRTensor>()) {
return MakeCSRTensor(value_list);
}
if (abs->isa<abstract::AbstractCOOTensor>()) {
return MakeCOOTensor(value_list);
}
auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(seq_abs);
// The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None.
size_t ref_idx = 0;
for (size_t i = 0; i < seq_abs->size(); i++) {
auto elem_abs = seq_abs->elements()[i];
if (elem_abs->isa<abstract::AbstractNone>()) {
continue;
}
ref_tuple[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
ref_idx++;
}
if (ref_idx != value_size) {
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
<< ref_idx;
}
ret = ref_tuple;
return ret;
}
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,

View File

@ -0,0 +1,33 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost outputs"""
import mindspore as ms
def test_none_in_outputs():
"""
Feature: Return outputs with None.
Description: The outermost network output has None.
Expectation: No exception.
"""
@ms.jit
def func(x, y):
return None, x + y, None
x = ms.Tensor(1)
y = ms.Tensor(2)
out = func(x, y)
assert out[0].asnumpy() == ms.Tensor(3).asnumpy()