!44760 Change ConvertAbstractToPython for tuple/list/dict.

Merge pull request !44760 from LiangZhibo/convert
This commit is contained in:
i-robot 2022-11-01 03:26:07 +00:00 committed by Gitee
commit a496a5612e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 114 additions and 29 deletions

View File

@ -361,10 +361,6 @@ py::object BuildValue(const ValuePtr &value_ptr) {
py::object AbstractTupleValueToPython(const AbstractTuple *tuple_abs) {
MS_EXCEPTION_IF_NULL(tuple_abs);
auto value = tuple_abs->BuildValue();
if (value->isa<AnyValue>()) {
return py::none();
}
const auto &elements = tuple_abs->elements();
size_t len = elements.size();
py::tuple value_tuple(len);
@ -446,12 +442,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_conver
}
dic[ATTR_SHAPE] = shape_tuple;
dic[ATTR_DTYPE] = dtype_tuple;
MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
if (arg_tuple->BuildValue()->isa<AnyValue>()) {
dic[ATTR_VALUE] = py::none();
} else {
dic[ATTR_VALUE] = value_tuple;
}
dic[ATTR_VALUE] = value_tuple;
if (dyn_value) {
dic[ATTR_MIN_VALUE] = min_value_tuple;
@ -505,20 +496,12 @@ py::dict AbstractDictionaryToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
MS_EXCEPTION_IF_NULL(arg_dict->BuildValue());
if (arg_dict->BuildValue()->isa<AnyValue>()) {
dic[ATTR_VALUE] = py::none();
} else {
dic[ATTR_VALUE] = value_dict;
}
dic[ATTR_VALUE] = value_dict;
return dic;
}
py::object AbstractListValueToPython(const AbstractList *list_abs) {
MS_EXCEPTION_IF_NULL(list_abs);
auto value = list_abs->BuildValue();
if (value->isa<AnyValue>()) {
return py::none();
}
const auto &elements = list_abs->elements();
size_t len = elements.size();
py::list value_list(len);
@ -585,12 +568,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
if (arg_list->BuildValue()->isa<AnyValue>()) {
dic[ATTR_VALUE] = py::none();
} else {
dic[ATTR_VALUE] = value_list;
}
dic[ATTR_VALUE] = value_list;
if (dyn_value) {
dic[ATTR_MIN_VALUE] = min_value_list;

View File

@ -2570,7 +2570,7 @@ class Concat(PrimitiveWithCheck):
def infer_value(self, input_x):
"""Implement Concat infer value"""
value = None
if input_x is not None:
if input_x is not None and None not in input_x:
value = Tensor(np.concatenate([x.asnumpy() for x in input_x], axis=self.axis))
return value
@ -2786,7 +2786,7 @@ class Stack(PrimitiveWithInfer):
tuple_value = value['value']
input_array = []
infered_value = None
if tuple_value is not None:
if tuple_value is not None and None not in tuple_value:
for item in tuple_value:
npy_item = item.asnumpy()
input_array.append(npy_item)

View File

@ -347,7 +347,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
_value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo')
if restart > size:
restart = size
if not is_within_graph(A):
if not is_within_graph(b):
x, info = GMRES(A, M, solve_method)(b, x0, tol, restart, maxiter, atol)
else:
x, info = GMRESV2(solve_method)(A, b, x0, tol, restart, maxiter, M, atol)
@ -516,7 +516,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
_type_check(func_name, maxiter, int, 'maxiter')
_value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
if not is_within_graph(A):
if not is_within_graph(b):
x, info = CG(A, M)(b, x0, tol, atol, maxiter)
else:
x, info = CGv2()(A, b, x0, tol, atol, maxiter, M)

View File

@ -0,0 +1,107 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dict_get """
from mindspore import Tensor, jit, context, mutable
from mindspore.ops.primitive import constexpr
context.set_context(mode=context.GRAPH_MODE)
@constexpr
def count_none(arg):
if arg is None:
raise ValueError("The arg is None")
a = 0
for e in arg:
if e is None:
a += 1
elif isinstance(e, (tuple, list)) and None in e:
a += 1
return a
def test_constexpr_input_with_variable_element_tuple():
"""
Feature: constexpr with variable element tuple input.
Description: If tuple is used as constexpr input, the variable element will be converted to None.
Expectation: No exception.
"""
@jit
def foo(x):
arg = (1, 2, x, x+1)
return count_none(arg)
out = foo(Tensor([1]))
assert out == 2
def test_constexpr_input_with_variable_element_tuple_2():
"""
Feature: constexpr with variable element tuple input.
Description: If tuple is used as constexpr input, the variable element will be converted to None.
Expectation: No exception.
"""
@jit
def foo(x):
arg = (1, 2, x, (x, 1, 2))
return count_none(arg)
out = foo(Tensor([1]))
assert out == 2
def test_constexpr_input_with_variable_element_list():
"""
Feature: constexpr with variable element list input.
Description: If list is used as constexpr input, the variable element will be converted to None.
Expectation: No exception.
"""
@jit
def foo(x):
arg = [1, 2, x, x+1]
return count_none(arg)
out = foo(Tensor([1]))
assert out == 2
def test_constexpr_input_with_variable_element_list_2():
"""
Feature: constexpr with variable element list input.
Description: If list is used as constexpr input, the variable element will be converted to None.
Expectation: No exception.
"""
@jit
def foo(x):
arg = [1, 2, x, [x, 1, 2]]
return count_none(arg)
out = foo(Tensor([1]))
assert out == 2
def test_constexpr_input_with_mutable_list():
"""
Feature: constexpr with mutable list.
Description: If mutable list is used as constexpr input, all elements will be converted to None.
Expectation: No exception.
"""
@jit
def foo(x):
arg = mutable([Tensor([1]), Tensor([2]), x])
return count_none(arg)
out = foo(Tensor([1]))
assert out == 3