forked from mindspore-Ecosystem/mindspore
!44760 Change ConvertAbstractToPython for tuple/list/dict.
Merge pull request !44760 from LiangZhibo/convert
This commit is contained in:
commit
a496a5612e
|
@ -361,10 +361,6 @@ py::object BuildValue(const ValuePtr &value_ptr) {
|
|||
|
||||
py::object AbstractTupleValueToPython(const AbstractTuple *tuple_abs) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_abs);
|
||||
auto value = tuple_abs->BuildValue();
|
||||
if (value->isa<AnyValue>()) {
|
||||
return py::none();
|
||||
}
|
||||
const auto &elements = tuple_abs->elements();
|
||||
size_t len = elements.size();
|
||||
py::tuple value_tuple(len);
|
||||
|
@ -446,12 +442,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_conver
|
|||
}
|
||||
dic[ATTR_SHAPE] = shape_tuple;
|
||||
dic[ATTR_DTYPE] = dtype_tuple;
|
||||
MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
|
||||
if (arg_tuple->BuildValue()->isa<AnyValue>()) {
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
dic[ATTR_VALUE] = value_tuple;
|
||||
}
|
||||
dic[ATTR_VALUE] = value_tuple;
|
||||
|
||||
if (dyn_value) {
|
||||
dic[ATTR_MIN_VALUE] = min_value_tuple;
|
||||
|
@ -505,20 +496,12 @@ py::dict AbstractDictionaryToPython(const AbstractBasePtr &abs_base) {
|
|||
dic[ATTR_SHAPE] = shape_list;
|
||||
dic[ATTR_DTYPE] = dtype_list;
|
||||
MS_EXCEPTION_IF_NULL(arg_dict->BuildValue());
|
||||
if (arg_dict->BuildValue()->isa<AnyValue>()) {
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
dic[ATTR_VALUE] = value_dict;
|
||||
}
|
||||
dic[ATTR_VALUE] = value_dict;
|
||||
return dic;
|
||||
}
|
||||
|
||||
py::object AbstractListValueToPython(const AbstractList *list_abs) {
|
||||
MS_EXCEPTION_IF_NULL(list_abs);
|
||||
auto value = list_abs->BuildValue();
|
||||
if (value->isa<AnyValue>()) {
|
||||
return py::none();
|
||||
}
|
||||
const auto &elements = list_abs->elements();
|
||||
size_t len = elements.size();
|
||||
py::list value_list(len);
|
||||
|
@ -585,12 +568,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert
|
|||
|
||||
dic[ATTR_SHAPE] = shape_list;
|
||||
dic[ATTR_DTYPE] = dtype_list;
|
||||
MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
|
||||
if (arg_list->BuildValue()->isa<AnyValue>()) {
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
dic[ATTR_VALUE] = value_list;
|
||||
}
|
||||
dic[ATTR_VALUE] = value_list;
|
||||
|
||||
if (dyn_value) {
|
||||
dic[ATTR_MIN_VALUE] = min_value_list;
|
||||
|
|
|
@ -2570,7 +2570,7 @@ class Concat(PrimitiveWithCheck):
|
|||
def infer_value(self, input_x):
|
||||
"""Implement Concat infer value"""
|
||||
value = None
|
||||
if input_x is not None:
|
||||
if input_x is not None and None not in input_x:
|
||||
value = Tensor(np.concatenate([x.asnumpy() for x in input_x], axis=self.axis))
|
||||
return value
|
||||
|
||||
|
@ -2786,7 +2786,7 @@ class Stack(PrimitiveWithInfer):
|
|||
tuple_value = value['value']
|
||||
input_array = []
|
||||
infered_value = None
|
||||
if tuple_value is not None:
|
||||
if tuple_value is not None and None not in tuple_value:
|
||||
for item in tuple_value:
|
||||
npy_item = item.asnumpy()
|
||||
input_array.append(npy_item)
|
||||
|
|
|
@ -347,7 +347,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
|
|||
_value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo')
|
||||
if restart > size:
|
||||
restart = size
|
||||
if not is_within_graph(A):
|
||||
if not is_within_graph(b):
|
||||
x, info = GMRES(A, M, solve_method)(b, x0, tol, restart, maxiter, atol)
|
||||
else:
|
||||
x, info = GMRESV2(solve_method)(A, b, x0, tol, restart, maxiter, M, atol)
|
||||
|
@ -516,7 +516,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
|
|||
_type_check(func_name, maxiter, int, 'maxiter')
|
||||
_value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
|
||||
|
||||
if not is_within_graph(A):
|
||||
if not is_within_graph(b):
|
||||
x, info = CG(A, M)(b, x0, tol, atol, maxiter)
|
||||
else:
|
||||
x, info = CGv2()(A, b, x0, tol, atol, maxiter, M)
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_dict_get """
|
||||
from mindspore import Tensor, jit, context, mutable
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@constexpr
|
||||
def count_none(arg):
|
||||
if arg is None:
|
||||
raise ValueError("The arg is None")
|
||||
a = 0
|
||||
for e in arg:
|
||||
if e is None:
|
||||
a += 1
|
||||
elif isinstance(e, (tuple, list)) and None in e:
|
||||
a += 1
|
||||
return a
|
||||
|
||||
|
||||
def test_constexpr_input_with_variable_element_tuple():
|
||||
"""
|
||||
Feature: constexpr with variable element tuple input.
|
||||
Description: If tuple is used as constexpr input, the variable element will be converted to None.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
arg = (1, 2, x, x+1)
|
||||
return count_none(arg)
|
||||
|
||||
out = foo(Tensor([1]))
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_constexpr_input_with_variable_element_tuple_2():
|
||||
"""
|
||||
Feature: constexpr with variable element tuple input.
|
||||
Description: If tuple is used as constexpr input, the variable element will be converted to None.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
arg = (1, 2, x, (x, 1, 2))
|
||||
return count_none(arg)
|
||||
|
||||
out = foo(Tensor([1]))
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_constexpr_input_with_variable_element_list():
|
||||
"""
|
||||
Feature: constexpr with variable element list input.
|
||||
Description: If list is used as constexpr input, the variable element will be converted to None.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
arg = [1, 2, x, x+1]
|
||||
return count_none(arg)
|
||||
|
||||
out = foo(Tensor([1]))
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_constexpr_input_with_variable_element_list_2():
|
||||
"""
|
||||
Feature: constexpr with variable element list input.
|
||||
Description: If list is used as constexpr input, the variable element will be converted to None.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
arg = [1, 2, x, [x, 1, 2]]
|
||||
return count_none(arg)
|
||||
|
||||
out = foo(Tensor([1]))
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_constexpr_input_with_mutable_list():
|
||||
"""
|
||||
Feature: constexpr with mutable list.
|
||||
Description: If mutable list is used as constexpr input, all elements will be converted to None.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
arg = mutable([Tensor([1]), Tensor([2]), x])
|
||||
return count_none(arg)
|
||||
|
||||
out = foo(Tensor([1]))
|
||||
assert out == 3
|
Loading…
Reference in New Issue