!2809 support Python built-in function 'enumerate'
Merge pull request !2809 from zhangbuxue/support_Python_built-in_functions_enumerate
commit aef002ad6c
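Summary: graph mode (`construct`) now supports the Python built-in `enumerate` over tuples and lists, including the optional `start` argument. To make this work, the composite `zip` operation is generalized from tuples to arbitrary sequences, the parser learns tuple-pattern `for` targets, and `mstype` gains `list_type`/`tuple_type` aliases for the compile-time type check. The commit also relaxes `Pack` to accept a single input tensor and carries a few pynative/style cleanups. A minimal usage sketch (the network below is illustrative, not taken from this commit):

import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)

class IndexWeightedSum(nn.Cell):
    def construct(self, x, y, z):
        total = 0
        for i, t in enumerate((x, y, z), 1):   # start=1 is supported
            total += i * t
        return total                           # 1*x + 2*y + 3*z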
@@ -116,6 +116,7 @@ convert_object_map = {
     T.partial: F.partial,
     T.zip: C.zip_operation,
     T.print: F.print_,
+    T.enumerate: M.enumerate_,

     # custom define operation
     T.iter: M.ms_iter,
@@ -104,6 +104,15 @@ def bool_(x):
     return x.__bool__()


+def enumerate_(x, start=0):
+    """Enumerate list or tuple."""
+    x_type = F.typeof(x)
+    ret = ()
+    if check_is_tuple_or_list(x_type, "enumerate"):
+        ret = zip(range(start, start + len(x)), x)
+    return ret
+
+
 def while_cond(x):
     """For while condtion, if the condition is a tensor, the loop will not be unrolled"""
     if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
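For reference, the new `enumerate_` desugars to a zip over an index range; a plain-Python sketch of the same semantics (illustrative only, the real version runs on graph-mode primitives):

def enumerate_sketch(x, start=0):
    # Equivalent of: zip(range(start, start + len(x)), x)
    return tuple(zip(range(start, start + len(x)), x))

assert enumerate_sketch((11, 22)) == ((0, 11), (1, 22))
assert enumerate_sketch((11, 22), start=1) == ((1, 11), (2, 22))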
@@ -113,6 +122,13 @@ def while_cond(x):
     return x


+@constexpr
+def check_is_tuple_or_list(x, op_name):
+    """check whether x is list or tuple."""
+    if isinstance(x, (mstype.list_type, mstype.tuple_type)):
+        return True
+    raise TypeError(f"For '{op_name}', the input parameter should be tuple or list, but got {x}.")
+
+
 @constexpr
 def check_is_tensor_bool_cond(shp):
     """check if tensor is a bool condition"""
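Note: `check_is_tuple_or_list` is decorated with `@constexpr`, so it executes while the graph is being compiled, and the `TypeError` surfaces at compile time rather than at run time. A minimal sketch of the pattern, assuming the public `mindspore.ops.constexpr` decorator (the function name here is hypothetical):

from mindspore.ops import constexpr

@constexpr
def check_rank(rank, op_name):
    # Runs at graph-compile time; `rank` must already be a Python constant here.
    if rank <= 0:
        raise ValueError(f"For '{op_name}', rank must be positive, but got {rank}.")
    return True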
@@ -27,7 +27,7 @@ from operator import ( # noqa

 # support system function call
 from builtins import ( # noqa
-    bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
+    bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate
 )

 # support functools
@@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
            'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
            'matmul', 'getitem', 'setitem',
            'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
-           'partial', 'print',
+           'partial', 'print', 'enumerate',
            'exp', 'log', 'sin', 'cos', 'tan']
@@ -181,7 +181,7 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGra
 }

 AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
-  if (arg_pairs.size() < 1) {
+  if (arg_pairs.empty()) {
     MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
   }
   bool found = false;
@@ -18,44 +18,44 @@

 #include "operator/composite/zip_operation.h"
 #include <algorithm>
-#include <utility>

 #include "pipeline/static_analysis/abstract_value.h"
 #include "ir/anf.h"
 #include "pipeline/static_analysis/dshape.h"
-#include "pipeline/static_analysis/param_validator.h"
 #include "operator/cc_implementations.h"
 #include "optimizer/opt.h"
-#include "utils/symbolic.h"
-#include "./common.h"
 #include "pybind_api/api_register.h"

 namespace mindspore {
 // namespace to support composite operators definition
 namespace prim {
 using mindspore::abstract::AbstractBase;
+using mindspore::abstract::AbstractList;
+using mindspore::abstract::AbstractSequeue;
+using mindspore::abstract::AbstractSequeuePtr;
 using mindspore::abstract::AbstractTuple;

 FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
   // zip operation:
   // input: tuple arguments
   // output: tuple of items of input iterated on every input
-  if (args_spec_list.size() == 0) {
-    MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << "For 'zip', there is at least one input.";
   }

-  auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
-    MS_EXCEPTION_IF_NULL(abs);
-    return abs->isa<AbstractTuple>();
-  });
-  if (!is_all_tuple) {
-    MS_LOG(EXCEPTION) << "zip input args should be tuple";
+  auto is_all_sequeue =
+    std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
+      MS_EXCEPTION_IF_NULL(abs);
+      return abs->isa<AbstractSequeue>();
+    });
+  if (!is_all_sequeue) {
+    MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequence.";
   }

-  auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
-                                  [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
-                                    return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
-                                  });
+  auto min_abs = std::min_element(
+    args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
+      return (x->cast<AbstractSequeuePtr>()->size() < y->cast<AbstractSequeuePtr>()->size());
+    });
   FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
   ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
   for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
@@ -65,12 +65,14 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
   // generate tuple output of ziped arguments input
   std::vector<AnfNodePtr> make_tuple_nodes;
   make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
-  for (size_t idx = 0; idx < (*min_abs)->cast<AbstractTuplePtr>()->size(); idx++) {
+  for (size_t idx = 0; idx < (*min_abs)->cast<AbstractSequeuePtr>()->size(); idx++) {
     std::vector<AnfNodePtr> make_tuple_zip_nodes;
     make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
+    std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
+    ValuePtr op = prim::GetPythonOps("getitem", module_name);
     for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) {
-      std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(prim::kPrimTupleGetItem),
-                                                   ret_graph->parameters()[arg_idx], NewValueNode(SizeToInt(idx))};
+      std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(op), ret_graph->parameters()[arg_idx],
+                                                   NewValueNode(SizeToInt(idx))};
       auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes);
       make_tuple_zip_nodes.push_back(tuple_get_item_op);
     }
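With `AbstractTuple` relaxed to `AbstractSequeue` and `kPrimTupleGetItem` replaced by the multitype `getitem` composite, `zip` in graph mode now accepts lists as well as tuples and truncates to the shortest input. An illustrative sketch (not part of this commit):

import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)

class ZipNet(nn.Cell):
    def construct(self, x, y, z):
        ret = ()
        for a, b in zip((x, y), [y, z, x]):   # tuple zipped with a list
            ret += ((a, b),)
        return ret                            # two pairs: truncated to the tuple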
@@ -229,6 +229,7 @@ AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr
   // Inputs: x, t
   return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
 }

 AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   // statement: isconstant(x)
@@ -1048,11 +1048,10 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
   CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param});
   CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)});
   py::object target_node = python_adapter::GetPyObjAttr(node, "target");
-  auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(target_node, "id"));
-  target_app->debug_info()->set_name(name_id);

   CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)});
-  body_block->WriteVariable(name_id, target_app);
+  WriteAssignVars(body_block, target_node, target_app);

   // link the variable name with the target
   auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
   iter_param->debug_info()->set_trace_info(it_info);
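The parser change above replaces the single-name `WriteVariable` with `WriteAssignVars`, which is what allows a `for` target to be a tuple pattern such as `i, j`. The now-supported shape, sketched in plain Python (in MindSpore this form is written inside `construct`):

total = 0
for i, j in enumerate((11, 22, 33)):
    total += i * j
assert total == 0 * 11 + 1 * 22 + 2 * 33  # 88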
@@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
 ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
+ABSTRACT_REPORT_NAME_TRAITS(Sequeue)

 template <typename T>
 std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
@@ -226,11 +226,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
   AbstractBasePtrList args_spec_list;
   for (size_t i = 0; i < size; i++) {
     ValuePtr input_value = PyAttrValue(py_args[i]);
-    if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
-      args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
-    } else {
-      args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
-    }
+    args_spec_list.emplace_back(abstract::FromValueInside(
+      input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()));
   }
   AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
   op_exec_info->abstract = infer_res;
@@ -512,7 +509,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   return result;
 }

-py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
+py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                   PynativeStatusCode *const status) {
   MS_EXCEPTION_IF_NULL(status);
   py::object result;
@@ -550,7 +547,7 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
 }

 AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
-  if (!grad_flag_ || graph_info_map_.size() == 0) {
+  if (!grad_flag_ || graph_info_map_.empty()) {
     return nullptr;
   }
   std::vector<AnfNodePtr> inputs;
@@ -753,7 +750,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
   if (py::isinstance<py::none>(name_attr)) {
     MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
   }
-  std::string param_name = py::cast<std::string>(name_attr);
+  auto param_name = py::cast<std::string>(name_attr);
   if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
     auto free_param = df_builder_->add_parameter();
     free_param->set_name(param_name);
@@ -97,6 +97,8 @@ tensor_type = typing.TensorType
 anything_type = typing.TypeAnything
 slice_type = typing.Slice
 ellipsis_type = typing.TypeEllipsis
+list_type = typing.List
+tuple_type = typing.Tuple

 number_type = (int8,
                int16,
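The new `list_type`/`tuple_type` aliases give Python-side code a way to match the sequence types returned by `F.typeof`; this is exactly what `check_is_tuple_or_list` above relies on. A minimal sketch (the helper name is illustrative):

from mindspore.common import dtype as mstype

def is_sequence_type(t):
    # `t` is a typing object such as the result of F.typeof(x) in graph mode.
    return isinstance(t, (mstype.list_type, mstype.tuple_type))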
@@ -65,9 +65,9 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
     tuple_len = len(tuple_index)
     for i in range(tuple_len):
         if i in int_positions:
-            tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
+            tuple_index_new += (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
         else:
-            tuple_index_new = tuple_index_new + (tuple_index[i],)
+            tuple_index_new += (tuple_index[i],)
     indexes_types = hyper_map(F.typeof, tuple_index_new)
     tensor_positions, slice_positions, ellipsis_position = \
         const_utils.separate_mixed_tensors_index(indexes_types, op_name)
@@ -1466,7 +1466,7 @@ class Concat(PrimitiveWithInfer):
 def _get_pack_shape(x_shape, x_type, axis, prim_name):
     """for pack output shape"""
     validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
-    validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
+    validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name)
     validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
     rank_base = len(x_shape[0])
     N = len(x_shape)
@@ -1761,6 +1761,10 @@ test_case_array_ops = [
         'desc_inputs': [[128, 128], [128, 128]],
         'desc_bprop': [[2, 128, 128]],
     }),
+    ('Pack_3', {
+        'block': NetForPackInput(P.Pack()),
+        'desc_inputs': [[2, 2]],
+        'desc_bprop': [[1, 2, 2]]}),
     ('Unpack_0', {
         'block': NetForUnpackInput(P.Unpack(axis=0)),
         'desc_inputs': [[2, 4]],
@@ -2226,10 +2230,6 @@ raise_set = [
          Tensor(np.ones((2, 2), np.float32)),
          Tensor(np.ones((2,), np.float32))),
         'desc_bprop': [[2, 3]]}),
-    ('Pack', {
-        'block': (NetForPackInput(P.Pack()), {'exception': ValueError}),
-        'desc_inputs': [[2, 2]],
-        'desc_bprop': [[1, 2, 2]]}),
    ('PReLU', {
        'block': (P.PReLU(), {'exception': ValueError}),
        'desc_inputs': [[2], [1]],
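The validator change in `_get_pack_shape` above (Rel.GT relaxed to Rel.GE against 1) makes packing a single tensor legal; accordingly, 'Pack_3' joins the positive test cases and the old 'Pack' entry expecting a ValueError is dropped from raise_set. A hedged sketch of the newly allowed call, assuming the 0.x-era `P.Pack` operator:

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.ones((2, 2), np.float32))
pack = P.Pack()
out = pack((x,))   # 1-element tuple no longer raises; output gains a leading axis of 1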
@@ -0,0 +1,181 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test enumerate"""
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def test_enumerate_list_const():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = [11, 22, 33, 44]
+
+        def construct(self):
+            index_sum = 0
+            value_sum = 0
+            for i, j in enumerate(self.value):
+                index_sum += i
+                value_sum += j
+            return index_sum, value_sum
+
+    net = Net()
+    assert net() == (6, 110)
+
+
+def test_enumerate_tuple_const():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = (11, 22, 33, 44)
+
+        def construct(self):
+            index_sum = 0
+            value_sum = 0
+            for i, j in enumerate(self.value):
+                index_sum += i
+                value_sum += j
+            return index_sum, value_sum
+
+    net = Net()
+    assert net() == (6, 110)
+
+
+def test_enumerate_list_parameter():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y, z):
+            index_sum = 0
+            value = [x, y, z]
+            ret = ()
+            for i, j in enumerate(value):
+                index_sum += i
+                ret += (j,)
+            return index_sum, ret
+
+    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
+    net = Net()
+    net(x, x, x)
+
+
+def test_enumerate_tuple_parameter():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y, z):
+            index_sum = 0
+            value = (x, y, z)
+            ret = ()
+            for i, j in enumerate(value):
+                index_sum += i
+                ret += (j,)
+            return index_sum, ret
+
+    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
+    net = Net()
+    net(x, x, x)
+
+
+def test_enumerate_tuple_const_1():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = (11, 22, 33, 44)
+
+        def construct(self):
+            index_sum = 0
+            value_sum = 0
+            for i in enumerate(self.value):
+                index_sum += i[0]
+                value_sum += i[1]
+            return index_sum, value_sum
+
+    net = Net()
+    assert net() == (6, 110)
+
+
+def test_enumerate_tuple_parameter_1():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y, z):
+            index_sum = 0
+            value = (x, y, z)
+            ret = ()
+            for i in enumerate(value):
+                index_sum += i[0]
+                ret += (i[1],)
+            return index_sum, ret
+
+    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
+    net = Net()
+    net(x, x, x)
+
+
+def test_enumerate_tuple_const_2():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = (11, 22, 33, 44)
+
+        def construct(self):
+            index_sum = 0
+            value_sum = 0
+            for i in enumerate(self.value, 1):
+                index_sum += i[0]
+                value_sum += i[1]
+            return index_sum, value_sum
+
+    net = Net()
+    assert net() == (10, 110)
+
+
+def test_enumerate_tuple_parameter_2():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y, z):
+            index_sum = 0
+            value = (x, y, z)
+            ret = ()
+            for i in enumerate(value, 2):
+                index_sum += i[0]
+                ret += (i[1],)
+            return index_sum, ret
+
+    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
+    net = Net()
+    net(x, x, x)
+
+
+def test_enumerate_parameter_type_error():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x):
+            return enumerate(x)
+
+    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
+    net = Net()
+    with pytest.raises(TypeError) as ex:
+        net(x)
+    assert "For 'enumerate', the input parameter should be tuple or list" in str(ex.value)
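As a quick plain-Python cross-check of the constants asserted in the tests above:

assert sum(range(4)) == 6            # default start: indices 0..3
assert 11 + 22 + 33 + 44 == 110      # value_sum in the const tests
assert sum(range(1, 5)) == 10        # start=1 over four items (tuple_const_2)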