!2809 support Python built-in function 'enumerate'

Merge pull request !2809 from zhangbuxue/support_Python_built-in_functions_enumerate
This commit is contained in:
mindspore-ci-bot 2020-07-06 09:38:47 +08:00 committed by Gitee
commit aef002ad6c
14 changed files with 240 additions and 40 deletions

View File

@ -116,6 +116,7 @@ convert_object_map = {
T.partial: F.partial,
T.zip: C.zip_operation,
T.print: F.print_,
T.enumerate: M.enumerate_,
# custom define operation
T.iter: M.ms_iter,

View File

@ -104,6 +104,15 @@ def bool_(x):
return x.__bool__()
def enumerate_(x, start=0):
"""Enumerate list or tuple."""
x_type = F.typeof(x)
ret = ()
if check_is_tuple_or_list(x_type, "enumerate"):
ret = zip(range(start, start + len(x)), x)
return ret
def while_cond(x):
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
@ -113,6 +122,13 @@ def while_cond(x):
return x
@constexpr
def check_is_tuple_or_list(x, op_name):
"""check whether x is list or tuple."""
if isinstance(x, (mstype.list_type, mstype.tuple_type)):
return True
raise TypeError(f"For '{op_name}', the input parameter should be tuple or list, but got {x}.")
@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""

View File

@ -27,7 +27,7 @@ from operator import ( # noqa
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate
)
# support functools
@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial', 'print',
'partial', 'print', 'enumerate',
'exp', 'log', 'sin', 'cos', 'tan']

View File

@ -181,7 +181,7 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
}
AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
if (arg_pairs.size() < 1) {
if (arg_pairs.empty()) {
MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
}
bool found = false;

View File

@ -18,44 +18,44 @@
#include "operator/composite/zip_operation.h"
#include <algorithm>
#include <utility>
#include "pipeline/static_analysis/abstract_value.h"
#include "ir/anf.h"
#include "pipeline/static_analysis/dshape.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "optimizer/opt.h"
#include "utils/symbolic.h"
#include "./common.h"
#include "pybind_api/api_register.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractSequeue;
using mindspore::abstract::AbstractSequeuePtr;
using mindspore::abstract::AbstractTuple;
FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// zip operation:
// input: tuple arguments
// output: tuple of items of input iterated on every input
if (args_spec_list.size() == 0) {
MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "For 'zip', there is at least one input.";
}
auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
return abs->isa<AbstractTuple>();
});
if (!is_all_tuple) {
MS_LOG(EXCEPTION) << "zip input args should be tuple";
auto is_all_sequeue =
std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
return abs->isa<AbstractSequeue>();
});
if (!is_all_sequeue) {
MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequence.";
}
auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
[](const AbstractBasePtr &x, const AbstractBasePtr &y) {
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
});
auto min_abs = std::min_element(
args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
return (x->cast<AbstractSequeuePtr>()->size() < y->cast<AbstractSequeuePtr>()->size());
});
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
@ -65,12 +65,14 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
// generate tuple output of ziped arguments input
std::vector<AnfNodePtr> make_tuple_nodes;
make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t idx = 0; idx < (*min_abs)->cast<AbstractTuplePtr>()->size(); idx++) {
for (size_t idx = 0; idx < (*min_abs)->cast<AbstractSequeuePtr>()->size(); idx++) {
std::vector<AnfNodePtr> make_tuple_zip_nodes;
make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
ValuePtr op = prim::GetPythonOps("getitem", module_name);
for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) {
std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(prim::kPrimTupleGetItem),
ret_graph->parameters()[arg_idx], NewValueNode(SizeToInt(idx))};
std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(op), ret_graph->parameters()[arg_idx],
NewValueNode(SizeToInt(idx))};
auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes);
make_tuple_zip_nodes.push_back(tuple_get_item_op);
}

View File

@ -229,6 +229,7 @@ AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr
// Inputs: x, t
return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
}
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// statement: isconstant(x)

View File

@ -1048,11 +1048,10 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param});
CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)});
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(target_node, "id"));
target_app->debug_info()->set_name(name_id);
CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)});
body_block->WriteVariable(name_id, target_app);
WriteAssignVars(body_block, target_node, target_app);
// link the variable name with the target
auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
iter_param->debug_info()->set_trace_info(it_info);

View File

@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
ABSTRACT_REPORT_NAME_TRAITS(Class)
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
template <typename T>
std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {

View File

@ -226,11 +226,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < size; i++) {
ValuePtr input_value = PyAttrValue(py_args[i]);
if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
} else {
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
}
args_spec_list.emplace_back(abstract::FromValueInside(
input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()));
}
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
op_exec_info->abstract = infer_res;
@ -512,7 +509,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
return result;
}
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
PynativeStatusCode *const status) {
MS_EXCEPTION_IF_NULL(status);
py::object result;
@ -550,7 +547,7 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
}
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
if (!grad_flag_ || graph_info_map_.size() == 0) {
if (!grad_flag_ || graph_info_map_.empty()) {
return nullptr;
}
std::vector<AnfNodePtr> inputs;
@ -753,7 +750,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
if (py::isinstance<py::none>(name_attr)) {
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
std::string param_name = py::cast<std::string>(name_attr);
auto param_name = py::cast<std::string>(name_attr);
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
auto free_param = df_builder_->add_parameter();
free_param->set_name(param_name);

View File

@ -97,6 +97,8 @@ tensor_type = typing.TensorType
anything_type = typing.TypeAnything
slice_type = typing.Slice
ellipsis_type = typing.TypeEllipsis
list_type = typing.List
tuple_type = typing.Tuple
number_type = (int8,
int16,

View File

@ -65,9 +65,9 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
tuple_len = len(tuple_index)
for i in range(tuple_len):
if i in int_positions:
tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
tuple_index_new += (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
else:
tuple_index_new = tuple_index_new + (tuple_index[i],)
tuple_index_new += (tuple_index[i],)
indexes_types = hyper_map(F.typeof, tuple_index_new)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)

View File

@ -1466,7 +1466,7 @@ class Concat(PrimitiveWithInfer):
def _get_pack_shape(x_shape, x_type, axis, prim_name):
"""for pack output shape"""
validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name)
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
rank_base = len(x_shape[0])
N = len(x_shape)

View File

@ -1761,6 +1761,10 @@ test_case_array_ops = [
'desc_inputs': [[128, 128], [128, 128]],
'desc_bprop': [[2, 128, 128]],
}),
('Pack_3', {
'block': NetForPackInput(P.Pack()),
'desc_inputs': [[2, 2]],
'desc_bprop': [[1, 2, 2]]}),
('Unpack_0', {
'block': NetForUnpackInput(P.Unpack(axis=0)),
'desc_inputs': [[2, 4]],
@ -2226,10 +2230,6 @@ raise_set = [
Tensor(np.ones((2, 2), np.float32)),
Tensor(np.ones((2,), np.float32))),
'desc_bprop': [[2, 3]]}),
('Pack', {
'block': (NetForPackInput(P.Pack()), {'exception': ValueError}),
'desc_inputs': [[2, 2]],
'desc_bprop': [[1, 2, 2]]}),
('PReLU', {
'block': (P.PReLU(), {'exception': ValueError}),
'desc_inputs': [[2], [1]],

View File

@ -0,0 +1,181 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test enumerate"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_enumerate_list_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [11, 22, 33, 44]
def construct(self):
index_sum = 0
value_sum = 0
for i, j in enumerate(self.value):
index_sum += i
value_sum += j
return index_sum, value_sum
net = Net()
assert net() == (6, 110)
def test_enumerate_tuple_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = (11, 22, 33, 44)
def construct(self):
index_sum = 0
value_sum = 0
for i, j in enumerate(self.value):
index_sum += i
value_sum += j
return index_sum, value_sum
net = Net()
assert net() == (6, 110)
def test_enumerate_list_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
index_sum = 0
value = [x, y, z]
ret = ()
for i, j in enumerate(value):
index_sum += i
ret += (j,)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
net(x, x, x)
def test_enumerate_tuple_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
index_sum = 0
value = (x, y, z)
ret = ()
for i, j in enumerate(value):
index_sum += i
ret += (j,)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
net(x, x, x)
def test_enumerate_tuple_const_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = (11, 22, 33, 44)
def construct(self):
index_sum = 0
value_sum = 0
for i in enumerate(self.value):
index_sum += i[0]
value_sum += i[1]
return index_sum, value_sum
net = Net()
assert net() == (6, 110)
def test_enumerate_tuple_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
index_sum = 0
value = (x, y, z)
ret = ()
for i in enumerate(value):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
net(x, x, x)
def test_enumerate_tuple_const_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = (11, 22, 33, 44)
def construct(self):
index_sum = 0
value_sum = 0
for i in enumerate(self.value, 1):
index_sum += i[0]
value_sum += i[1]
return index_sum, value_sum
net = Net()
assert net() == (10, 110)
def test_enumerate_tuple_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
index_sum = 0
value = (x, y, z)
ret = ()
for i in enumerate(value, 2):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
net(x, x, x)
def test_enumerate_parameter_type_error():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
return enumerate(x)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
with pytest.raises(TypeError) as ex:
net(x)
assert "For 'enumerate', the input parameter should be tuple or list" in str(ex.value)