!2809 support Python built-in function 'enumerate'
Merge pull request !2809 from zhangbuxue/support_Python_built-in_functions_enumerate
This commit is contained in: aef002ad6c
@@ -116,6 +116,7 @@ convert_object_map = {
     T.partial: F.partial,
     T.zip: C.zip_operation,
     T.print: F.print_,
+    T.enumerate: M.enumerate_,

     # custom define operation
     T.iter: M.ms_iter,
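Note: convert_object_map is the table the parser consults to swap supported
Python builtins for graph-compatible implementations, so registering
T.enumerate against M.enumerate_ is what makes the builtin resolvable in
graph mode. A hedged sketch of the lookup idea (the helper below is
illustrative, not the real resolver):

    # Sketch only: how a builtin such as `enumerate` is redirected at parse
    # time. `resolve` is hypothetical; the real resolution happens inside
    # the parser, not through a helper like this.
    def resolve(obj, convert_object_map):
        """Return the registered graph-mode replacement, or obj unchanged."""
        return convert_object_map.get(obj, obj)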
@@ -104,6 +104,15 @@ def bool_(x):
     return x.__bool__()


+def enumerate_(x, start=0):
+    """Enumerate list or tuple."""
+    x_type = F.typeof(x)
+    ret = ()
+    if check_is_tuple_or_list(x_type, "enumerate"):
+        ret = zip(range(start, start + len(x)), x)
+    return ret
+
+
 def while_cond(x):
     """For while condtion, if the condition is a tensor, the loop will not be unrolled"""
     if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):

@@ -113,6 +122,13 @@ def while_cond(x):
     return x


+@constexpr
+def check_is_tuple_or_list(x, op_name):
+    """check whether x is list or tuple."""
+    if isinstance(x, (mstype.list_type, mstype.tuple_type)):
+        return True
+    raise TypeError(f"For '{op_name}', the input parameter should be tuple or list, but got {x}.")
+
+
 @constexpr
 def check_is_tensor_bool_cond(shp):
     """check if tensor is a bool condition"""
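Note: enumerate_ lowers the builtin to a zip over an index range, so on
tuples and lists it matches plain Python. A minimal pure-Python sketch of
the same semantics (illustration only; in graph mode this is evaluated
during compilation):

    def enumerate_sketch(x, start=0):
        """Pure-Python model of enumerate_: pair a running index with each item."""
        if not isinstance(x, (list, tuple)):
            raise TypeError(f"For 'enumerate', the input parameter should be tuple or list, but got {type(x)}.")
        return tuple(zip(range(start, start + len(x)), x))

    assert enumerate_sketch((11, 22, 33), start=1) == ((1, 11), (2, 22), (3, 33))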
@@ -27,7 +27,7 @@ from operator import ( # noqa

 # support system function call
 from builtins import ( # noqa
-    bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
+    bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate
 )

 # support functools
@@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
     'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
     'matmul', 'getitem', 'setitem',
     'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
-    'partial', 'print',
+    'partial', 'print', 'enumerate',
     'exp', 'log', 'sin', 'cos', 'tan']
@@ -181,7 +181,7 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr
 }

 AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
-  if (arg_pairs.size() < 1) {
+  if (arg_pairs.empty()) {
     MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
   }
   bool found = false;
@@ -18,44 +18,44 @@

 #include "operator/composite/zip_operation.h"
 #include <algorithm>
 #include <utility>

 #include "pipeline/static_analysis/abstract_value.h"
 #include "ir/anf.h"
 #include "pipeline/static_analysis/dshape.h"
 #include "pipeline/static_analysis/param_validator.h"
 #include "operator/cc_implementations.h"
 #include "optimizer/opt.h"
 #include "utils/symbolic.h"
 #include "./common.h"
 #include "pybind_api/api_register.h"

 namespace mindspore {
 // namespace to support composite operators definition
 namespace prim {
 using mindspore::abstract::AbstractBase;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractSequeue;
 using mindspore::abstract::AbstractSequeuePtr;
 using mindspore::abstract::AbstractTuple;

 FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
   // zip operation:
   // input: tuple arguments
   // output: tuple of items of input iterated on every input
-  if (args_spec_list.size() == 0) {
-    MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << "For 'zip', there is at least one input.";
   }

-  auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
-    MS_EXCEPTION_IF_NULL(abs);
-    return abs->isa<AbstractTuple>();
-  });
-  if (!is_all_tuple) {
-    MS_LOG(EXCEPTION) << "zip input args should be tuple";
+  auto is_all_sequeue =
+    std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
+      MS_EXCEPTION_IF_NULL(abs);
+      return abs->isa<AbstractSequeue>();
+    });
+  if (!is_all_sequeue) {
+    MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequence.";
   }

-  auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
-                                  [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
-                                    return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
-                                  });
+  auto min_abs = std::min_element(
+    args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
+      return (x->cast<AbstractSequeuePtr>()->size() < y->cast<AbstractSequeuePtr>()->size());
+    });
   FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
   ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
   for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
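Note: the std::min_element call above is what gives zip its Python
semantics: the output length is that of the shortest input sequence. A
pure-Python sketch of the computed result (illustrative):

    def zip_sketch(*seqs):
        """Model of ZipOperation: index every input up to the shortest length."""
        shortest = min(len(s) for s in seqs)
        return tuple(tuple(s[i] for s in seqs) for i in range(shortest))

    assert zip_sketch((1, 2, 3), [4, 5]) == ((1, 4), (2, 5))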
@@ -65,12 +65,14 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
   // generate tuple output of ziped arguments input
   std::vector<AnfNodePtr> make_tuple_nodes;
   make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
-  for (size_t idx = 0; idx < (*min_abs)->cast<AbstractTuplePtr>()->size(); idx++) {
+  for (size_t idx = 0; idx < (*min_abs)->cast<AbstractSequeuePtr>()->size(); idx++) {
     std::vector<AnfNodePtr> make_tuple_zip_nodes;
     make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
+    std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
+    ValuePtr op = prim::GetPythonOps("getitem", module_name);
     for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) {
-      std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(prim::kPrimTupleGetItem),
-                                                   ret_graph->parameters()[arg_idx], NewValueNode(SizeToInt(idx))};
+      std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(op), ret_graph->parameters()[arg_idx],
+                                                   NewValueNode(SizeToInt(idx))};
       auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes);
       make_tuple_zip_nodes.push_back(tuple_get_item_op);
     }
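Note: swapping kPrimTupleGetItem for the composite getitem from
getitem_impl matters because the tuple-specific primitive cannot index
lists; the multitype getitem dispatches on the sequence kind, which is what
lets zip (and therefore enumerate) accept both. In plain Python terms:

    # Illustration: one indexing entry point for both sequence kinds,
    # mirroring the multitype dispatch the C++ change switches to.
    def getitem_generic(seq, i):
        return seq[i]

    assert getitem_generic((1, 2), 1) == getitem_generic([1, 2], 1) == 2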
@@ -229,6 +229,7 @@ AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr
   // Inputs: x, t
   return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
 }

 AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   // statement: isconstant(x)
@@ -1048,11 +1048,10 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
   CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param});
   CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)});
   py::object target_node = python_adapter::GetPyObjAttr(node, "target");
-  auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(target_node, "id"));
-  target_app->debug_info()->set_name(name_id);

   CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)});
-  body_block->WriteVariable(name_id, target_app);
+  WriteAssignVars(body_block, target_node, target_app);

   // link the variable name with the target
   auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
   iter_param->debug_info()->set_trace_info(it_info);
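Note: replacing WriteVariable with WriteAssignVars lets the for-loop target
be a general assignment pattern rather than a single name, which is what
the `for i, j in enumerate(...)` tests below rely on. In plain Python terms:

    # Hedged illustration: the loop target may now be a tuple pattern,
    # not only a bare name.
    for i, j in enumerate(("a", "b")):
        pass  # i binds the index, j the value, via tuple-target assignment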
@@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
 ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
+ABSTRACT_REPORT_NAME_TRAITS(Sequeue)

 template <typename T>
 std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
@@ -226,11 +226,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
   AbstractBasePtrList args_spec_list;
   for (size_t i = 0; i < size; i++) {
     ValuePtr input_value = PyAttrValue(py_args[i]);
-    if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
-      args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
-    } else {
-      args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
-    }
+    args_spec_list.emplace_back(abstract::FromValueInside(
+      input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()));
   }
   AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
   op_exec_info->abstract = infer_res;
@@ -512,7 +509,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   return result;
 }

-py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
+py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                   PynativeStatusCode *const status) {
   MS_EXCEPTION_IF_NULL(status);
   py::object result;
@@ -550,7 +547,7 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
 }

 AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
-  if (!grad_flag_ || graph_info_map_.size() == 0) {
+  if (!grad_flag_ || graph_info_map_.empty()) {
     return nullptr;
   }
   std::vector<AnfNodePtr> inputs;
@@ -753,7 +750,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
   if (py::isinstance<py::none>(name_attr)) {
     MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
   }
-  std::string param_name = py::cast<std::string>(name_attr);
+  auto param_name = py::cast<std::string>(name_attr);
   if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
     auto free_param = df_builder_->add_parameter();
     free_param->set_name(param_name);
@@ -97,6 +97,8 @@ tensor_type = typing.TensorType
 anything_type = typing.TypeAnything
 slice_type = typing.Slice
 ellipsis_type = typing.TypeEllipsis
+list_type = typing.List
+tuple_type = typing.Tuple

 number_type = (int8,
                int16,
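Note: the new list_type/tuple_type aliases back the @constexpr validator
added above, which tests statically inferred types with isinstance. A
minimal sketch, assuming the aliases are re-exported through
mindspore.common.dtype:

    from mindspore.common import dtype as mstype

    def is_sequence_type(t):
        """True when a statically inferred type object is a list or tuple type."""
        return isinstance(t, (mstype.list_type, mstype.tuple_type))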
@@ -65,9 +65,9 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
     tuple_len = len(tuple_index)
     for i in range(tuple_len):
         if i in int_positions:
-            tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
+            tuple_index_new += (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
         else:
-            tuple_index_new = tuple_index_new + (tuple_index[i],)
+            tuple_index_new += (tuple_index[i],)
     indexes_types = hyper_map(F.typeof, tuple_index_new)
     tensor_positions, slice_positions, ellipsis_position = \
         const_utils.separate_mixed_tensors_index(indexes_types, op_name)
@@ -1466,7 +1466,7 @@ class Concat(PrimitiveWithInfer):
 def _get_pack_shape(x_shape, x_type, axis, prim_name):
     """for pack output shape"""
     validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
-    validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
+    validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name)
     validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
     rank_base = len(x_shape[0])
     N = len(x_shape)
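Note: relaxing the length check from Rel.GT to Rel.GE means Pack now
accepts a single input tensor instead of raising ValueError; the new
Pack_3 case below exercises exactly that, and the matching raise_set entry
is dropped. A hedged usage sketch:

    # Sketch: packing one (2, 2) tensor yields shape (1, 2, 2) after this change.
    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    out = P.Pack()((Tensor(np.ones((2, 2), np.float32)),))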
@@ -1761,6 +1761,10 @@ test_case_array_ops = [
     'desc_inputs': [[128, 128], [128, 128]],
     'desc_bprop': [[2, 128, 128]],
 }),
+    ('Pack_3', {
+        'block': NetForPackInput(P.Pack()),
+        'desc_inputs': [[2, 2]],
+        'desc_bprop': [[1, 2, 2]]}),
     ('Unpack_0', {
         'block': NetForUnpackInput(P.Unpack(axis=0)),
         'desc_inputs': [[2, 4]],
@@ -2226,10 +2230,6 @@ raise_set = [
         Tensor(np.ones((2, 2), np.float32)),
         Tensor(np.ones((2,), np.float32))),
      'desc_bprop': [[2, 3]]}),
-    ('Pack', {
-        'block': (NetForPackInput(P.Pack()), {'exception': ValueError}),
-        'desc_inputs': [[2, 2]],
-        'desc_bprop': [[1, 2, 2]]}),
     ('PReLU', {
         'block': (P.PReLU(), {'exception': ValueError}),
         'desc_inputs': [[2], [1]],
@@ -0,0 +1,181 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test enumerate"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)


def test_enumerate_list_const():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = [11, 22, 33, 44]

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i, j in enumerate(self.value):
                index_sum += i
                value_sum += j
            return index_sum, value_sum

    net = Net()
    assert net() == (6, 110)


def test_enumerate_tuple_const():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i, j in enumerate(self.value):
                index_sum += i
                value_sum += j
            return index_sum, value_sum

    net = Net()
    assert net() == (6, 110)


def test_enumerate_list_parameter():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y, z):
            index_sum = 0
            value = [x, y, z]
            ret = ()
            for i, j in enumerate(value):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
    net = Net()
    net(x, x, x)


def test_enumerate_tuple_parameter():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y, z):
            index_sum = 0
            value = (x, y, z)
            ret = ()
            for i, j in enumerate(value):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
    net = Net()
    net(x, x, x)


def test_enumerate_tuple_const_1():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i in enumerate(self.value):
                index_sum += i[0]
                value_sum += i[1]
            return index_sum, value_sum

    net = Net()
    assert net() == (6, 110)


def test_enumerate_tuple_parameter_1():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y, z):
            index_sum = 0
            value = (x, y, z)
            ret = ()
            for i in enumerate(value):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
    net = Net()
    net(x, x, x)


def test_enumerate_tuple_const_2():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i in enumerate(self.value, 1):
                index_sum += i[0]
                value_sum += i[1]
            return index_sum, value_sum

    net = Net()
    assert net() == (10, 110)


def test_enumerate_tuple_parameter_2():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y, z):
            index_sum = 0
            value = (x, y, z)
            ret = ()
            for i in enumerate(value, 2):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
    net = Net()
    net(x, x, x)


def test_enumerate_parameter_type_error():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return enumerate(x)

    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
    net = Net()
    with pytest.raises(TypeError) as ex:
        net(x)
    assert "For 'enumerate', the input parameter should be tuple or list" in str(ex.value)
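Note: for a quick check outside the pytest runner, the new cases can be
invoked directly (minimal driver sketch; assumes the test file above is
run as a script):

    if __name__ == "__main__":
        test_enumerate_tuple_const_2()         # start offset of 1
        test_enumerate_parameter_type_error()  # tensor input rejected
        print("enumerate cases passed")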