From 75fec82b5248f656f0a97701ce32f54cc0c9be79 Mon Sep 17 00:00:00 2001
From: kingfo
Date: Tue, 14 Apr 2020 11:52:56 +0800
Subject: [PATCH] resolve pynative operator issue
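
In PyNative mode, scalar arguments to a primitive are now converted to tensors
according to the primitive's signature (ConvertInputs in pynative_execute.cc),
Tensor gains reflected and unary arithmetic operators that accept Python
scalars, Parameter arithmetic returns a new Parameter instead of a bare value,
and ConcatOffset moves from array_ops to _grad_ops.

A minimal illustration of the Tensor arithmetic this enables (an illustrative
snippet mirroring the new test in tests/ut/python/ir/test_tensor.py, not part
of the diff below):

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.ones((3, 3)) * 4)
    print((1 + x).asnumpy())   # reflected add via Tensor.__radd__ -> all 5.0
    print((8 / x).asnumpy())   # reflected division via Tensor.__rtruediv__ -> all 2.0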
---
 mindspore/_extends/builtin_operations.py     |  8 +-
 mindspore/ccsrc/pipeline/pipeline.cc         | 12 +--
 mindspore/ccsrc/pynative/pynative_execute.cc | 87 ++++++++++++++-----
 mindspore/common/parameter.py                | 18 ++--
 mindspore/common/tensor.py                   | 43 ++++++---
 mindspore/ops/_grad/grad_array_ops.py        |  2 +-
 mindspore/ops/_utils/__init__.py             |  4 +-
 .../ops/_utils/{broadcast.py => utils.py}    | 29 ++++++-
 mindspore/ops/operations/__init__.py         |  3 +-
 mindspore/ops/operations/_grad_ops.py        | 28 ++++++
 mindspore/ops/operations/array_ops.py        | 53 +----------
 mindspore/ops/operations/other_ops.py        |  3 +
 tests/ut/python/ir/test_tensor.py            | 22 +++++
 tests/vm_impl/array_ops_vm_impl.py           |  2 +-
 14 files changed, 208 insertions(+), 106 deletions(-)
 rename mindspore/ops/_utils/{broadcast.py => utils.py} (62%)

diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py
index 087b7047196..6fea07425e6 100644
--- a/mindspore/_extends/builtin_operations.py
+++ b/mindspore/_extends/builtin_operations.py
@@ -125,7 +125,7 @@ def list_len(x):
     return len(x)
 
 
-# only used in PyNative modes
+# only used in PyNative mode
 def partial(*args):
     """Implement `partial`."""
     func = args[0].__call__
@@ -133,10 +133,14 @@ def partial(*args):
     return partial_func
 
 
-# only used in PyNative modes
+# only used in PyNative mode
 def depend(value, expr):
     return value
 
+# only used in PyNative mode
+def make_ref(key, value, ref):
+    return value
+
 
 def scalar_cast(x, t):
     """Implement scalar_cast."""
diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index 003d4c15e99..cd4fe28db93 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -616,17 +616,19 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
     return ExecDFGraph(info_, args, phase_s);
   }
 #else
-  if (backend == "ge") {
-    std::shared_ptr<py::object> ret_val = std::make_shared<py::object>();
+  if (backend == "ms" || backend == "ge") {
+    auto ret_val = std::make_shared<py::object>();
     if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
       if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
         return *ret_val;
       }
     }
-    if (args.size() > 0) {
-      return args[0];
+    if (backend == "ge") {
+      if (args.size() > 0) {
+        return args[0];
+      }
+      return args;
     }
-    return args;
   }
 #endif
   std::size_t full_arg_size = ArgListSize(phase_s);
diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index 4144ad2d6ba..5620634bcca 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -20,11 +20,13 @@
 #include
 #include
 #include
+#include
 
 #include "utils/any.h"
 #include "utils/utils.h"
 #include "utils/context/ms_context.h"
 #include "operator/ops.h"
+#include "operator/composite/do_signature.h"
 #include "pipeline/parse/data_converter.h"
 #include "pipeline/static_analysis/prim.h"
 #include "session/session_factory.h"
@@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) {
   return converted_ret;
 }
 
+py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) {
+  auto signature = prim->signatures();
+  std::vector<SignatureEnumDType> dtypes;
+  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
+                       [](const Signature& sig) { return sig.dtype; });
+  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
+  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
+    return py_args;
+  }
+  std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
+  for (size_t i = 0; i < dtypes.size(); ++i) {
+    auto it = type_indexs.find(dtypes[i]);
+    if (it == type_indexs.end()) {
+      (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
+    } else {
+      it->second.push_back(i);
+    }
+  }
+  std::map<SignatureEnumDType, size_t> dst_type;
+  for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
+    auto type = it->first;
+    auto indexs = it->second;
+    if (indexs.size() < 2) {
+      continue;
+    }
+    size_t m_index = indexs[0];
+    for (size_t i = 1; i < indexs.size(); ++i) {
+      if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
+        m_index = indexs[i];
+      }
+    }
+    (void)dst_type.insert(std::make_pair(type, m_index));
+  }
+  py::tuple py_inputs(py_args.size());
+  for (size_t i = 0; i < py_args.size(); ++i) {
+    auto it = dst_type.find(dtypes[i]);
+    if (it != dst_type.end() && it->second != i &&
+        (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
+      auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
+      if (py::isinstance<py::int_>(py_args[i])) {
+        py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
+      } else {
+        py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
+      }
+      continue;
+    }
+    py_inputs[i] = py_args[i];
+  }
+  return py_inputs;
+}
+
 void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) {
   size_t size = py_args.size();
   AbstractBasePtrList args_spec_list;
@@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
   auto op_exec_info = std::make_shared<OpExecInfo>();
   MS_EXCEPTION_IF_NULL(op_exec_info);
   op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
-  if (py::isinstance<py::str>(args[PY_PRIM])) {
-    py::module ops_mod = py::module::import("mindspore.ops.operations");
-    py::object py_primitive = ops_mod.attr(op_exec_info->op_name.c_str())();
-    op_exec_info->py_primitive = py::cast<PrimitivePyPtr>(py_primitive);
-    py::dict none_attrs = py::dict();
-    op_exec_info->op_attrs = none_attrs;
-  } else {
-    PrimitivePyPtr prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
-    auto pyobj = prim->GetPyObj();
-    if (pyobj == nullptr) {
-      MS_LOG(EXCEPTION) << "pyobj is empty";
-    }
-    py::tuple py_args = args[PY_INPUTS];
-    // use python infer method
-    if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
-      PynativeInfer(prim, py_args, op_exec_info.get());
-    }
-    op_exec_info->py_primitive = prim;
-    op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
+  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
+  auto pyobj = prim->GetPyObj();
+  if (pyobj == nullptr) {
+    MS_LOG(EXCEPTION) << "pyobj is empty";
   }
-  op_exec_info->op_inputs = args[PY_INPUTS];
+  py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
+  // use python infer method
+  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
+    PynativeInfer(prim, py_args, op_exec_info.get());
+  }
+  op_exec_info->py_primitive = prim;
+  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
+  op_exec_info->op_inputs = py_args;
   op_exec_info->inputs_mask = args[PY_INPUT_MASK];
   if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
-    MS_LOG(ERROR) << "" << op_exec_info->op_name << " op_inputs size not equal op_mask";
+    MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
     return nullptr;
   }
   return op_exec_info;
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index c8ddf0eac62..c354bcd2352 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -14,7 +14,7 @@
 # ============================================================================
 
 """Parameter for cell."""
-from copy import copy
+from copy import copy, deepcopy
 import numpy as np
 from .initializer import initializer
 from .tensor import Tensor
@@ -156,16 +156,24 @@ class Parameter:
         return self.default_input
 
     def __add__(self, other):
-        return self.default_input + other
+        res = deepcopy(self)
+        res.default_input = res.default_input + other
+        return res
 
     def __sub__(self, other):
-        return self.default_input - other
+        res = deepcopy(self)
+        res.default_input = res.default_input - other
+        return res
 
     def __mul__(self, other):
-        return self.default_input * other
+        res = deepcopy(self)
+        res.default_input = res.default_input * other
+        return res
 
     def __truediv__(self, other):
-        return self.default_input / other
+        res = deepcopy(self)
+        res.default_input = res.default_input / other
+        return res
 
     def set_parameter_data(self, data):
         if isinstance(data, (Tensor, list, int, float,
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 709b2ae2805..70b8b169ca1 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -70,45 +70,60 @@ class Tensor(Tensor_):
         return str(self.__str__())
 
     def __add__(self, other):
-        if not isinstance(other, Tensor):
-            raise TypeError("input_data must be a tensor")
+        check_type('tensor input_data', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__add__')(self, other)
         return out
 
     def __mul__(self, other):
-        if not isinstance(other, Tensor):
-            raise TypeError("input_data must be a tensor")
+        check_type('tensor input_data', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__mul__')(self, other)
         return out
 
+    def __neg__(self):
+        return Tensor(-self.asnumpy())
+
     def __iadd__(self, other):
         out = self.__add__(other)
         return out
 
+    def __radd__(self, other):
+        check_type('tensor operation input', other, (Tensor, float, int))
+        out = tensor_operator_registry.get('__add__')(other, self)
+        return out
+
     def __imul__(self, other):
         out = self.__mul__(other)
         return out
 
+    def __rmul__(self, other):
+        check_type('tensor operation input', other, (Tensor, float, int))
+        out = tensor_operator_registry.get('__mul__')(other, self)
+        return out
+
     def __truediv__(self, other):
-        if isinstance(other, (int, float)):
-            other_tensor = Tensor(other, self.dtype())
-        elif isinstance(other, Tensor):
-            other_tensor = other
-        else:
-            raise TypeError("unsupported type for div operation")
-        out = tensor_operator_registry.get('__div__')(self, other_tensor)
+        check_type('tensor operation input', other, (Tensor, float, int))
+        out = tensor_operator_registry.get('__div__')(self, other)
+        return out
+
+    def __rtruediv__(self, other):
+        check_type('tensor operation input', other, (Tensor, float, int))
+        out = tensor_operator_registry.get('__div__')(other, self)
         return out
 
     def __sub__(self, other):
-        if not isinstance(other, Tensor):
-            raise TypeError("input_data must be a tensor")
-        out = self.__add__(Tensor(-other.asnumpy()))
+        check_type('tensor operation input', other, (Tensor, float, int))
+        out = self.__add__(-other)
         return out
 
     def __isub__(self, other):
         out = self.__sub__(other)
         return out
 
+    def __rsub__(self, other):
+        check_type('tensor operation input', other, (Tensor, float, int))
+        out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
+        return out
+
     def __str__(self):
         if self.dtype() == mstype.type_none:
             return "Unknown Tensor type!"
diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index abad030ae91..35d37b3ada8 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -191,7 +191,7 @@ def get_bprop_concat(self):
 
     def bprop(x, out, dout):
         dx = ()
-        out_offset = P.ConcatOffset(F.tuple_len(x), axis)(x)
+        out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
         for i in range(F.tuple_len(x)):
             slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
             dx = dx + (slice_out,)
diff --git a/mindspore/ops/_utils/__init__.py b/mindspore/ops/_utils/__init__.py
index 00ce07453a7..8fe11029684 100644
--- a/mindspore/ops/_utils/__init__.py
+++ b/mindspore/ops/_utils/__init__.py
@@ -14,6 +14,6 @@
 # ============================================================================
 
 """ops utils."""
-from .broadcast import _get_broadcast_shape
+from .utils import _get_broadcast_shape, _get_concat_offset
 
-__all__ = ['_get_broadcast_shape']
+__all__ = ['_get_broadcast_shape', '_get_concat_offset']
diff --git a/mindspore/ops/_utils/broadcast.py b/mindspore/ops/_utils/utils.py
similarity index 62%
rename from mindspore/ops/_utils/broadcast.py
rename to mindspore/ops/_utils/utils.py
index c71158de57b..fbd81c4f0d9 100644
--- a/mindspore/ops/_utils/broadcast.py
+++ b/mindspore/ops/_utils/utils.py
@@ -13,8 +13,11 @@
 # limitations under the License.
 # ============================================================================
 
-"""broadcast"""
+"""utils for operator"""
+from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Rel
+from ...common import dtype as mstype
 
 
 def _get_broadcast_shape(x_shape, y_shape, prim_name):
     """
@@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
         broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
     broadcast_shape = broadcast_shape_front + broadcast_shape_back
     return broadcast_shape
+
+
+def _get_concat_offset(x_shp, x_type, axis):
+    """for concat and concatoffset check args and compute offset"""
+    validator.check_type("shape", x_shp, [tuple])
+    validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
+    validator.check_subclass("shape0", x_type[0], mstype.tensor)
+    validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
+    rank_base = len(x_shp[0])
+    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
+    if axis < 0:
+        axis = axis + rank_base
+    all_shp = x_shp[0][axis]
+    offset = [0,]
+    for i in range(1, len(x_shp)):
+        v = x_shp[i]
+        validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
+        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
+        for j in range(rank_base):
+            if j != axis and v[j] != x_shp[0][j]:
+                raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
+        offset.append(all_shp)
+        all_shp += v[axis]
+    return offset, all_shp, axis
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 492ebae4446..e1dd8e36c5e 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -19,7 +19,7 @@ Primitive operator classes.
 A collection of operators to build nerual networks or computing functions.
""" -from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Pack, Unpack, +from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Diag, DiagPart, DType, ExpandDims, Eye, Fill, GatherNd, GatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, @@ -200,7 +200,6 @@ __all__ = [ 'LogicalOr', 'Size', 'DepthwiseConv2dNative', - 'ConcatOffset', 'UnsortedSegmentSum', "AllGather", "AllReduce", diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 6143f9d0a05..48d1a2a89ca 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..._checkparam import ParamValidator as validator from ..._checkparam import Rel, check_int_positive, check_bool +from .._utils import _get_concat_offset from ...common import dtype as mstype @@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) return x_type +class ConcatOffset(PrimitiveWithInfer): + """primitive for computing Concat's gradient.""" + + @prim_attr_register + def __init__(self, N=2, axis=0): + """init ConcatOffset""" + + def __infer__(self, input_x): + axis = self.axis + x_shp = input_x['shape'] + x_type = input_x['dtype'] + offset, _, axis = _get_concat_offset(x_shp, x_type, axis) + self.add_prim_attr('T', x_type[0].element_type()) + offset_values = [] + for i in range(len(x_shp)): + values = [] + for j in range(len(x_shp[0])): + value = 0 + if j == axis: + value = offset[i] + values.append(value) + offset_values.append(tuple(values)) + out = {'shape': None, + 'dtype': None, + 'value': tuple(offset_values)} + return out + class Conv2DBackpropFilter(PrimitiveWithInfer): """ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index a7c3f504404..da16a2ab297 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -29,6 +29,7 @@ from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor from ..operations.math_ops import _infer_shape_reduce +from .._utils import _get_concat_offset from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register def _check_infer_attr_reduce(axis, keep_dims): @@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer): return out -def _get_concat_offset(x_shp, x_type, axis): - """for concat and concatoffset check args and compute offset""" - validator.check_type("shape", x_shp, [tuple]) - validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) - validator.check_subclass("shape0", x_type[0], mstype.tensor) - validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) - rank_base = len(x_shp[0]) - validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) - if axis < 0: - axis = axis + rank_base - all_shp = x_shp[0][axis] - offset = [0,] - for i in range(1, len(x_shp)): - v = x_shp[i] - validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) - validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) - for j in range(rank_base): - if j != axis and v[j] != x_shp[0][j]: - raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) - offset.append(all_shp) - all_shp += v[axis] - return offset, 
-
-
 class Concat(PrimitiveWithInfer):
     r"""
     Concat tensor in specified axis.
@@ -1531,34 +1508,6 @@ class Slice(PrimitiveWithInfer):
                 'value': None}
 
 
-class ConcatOffset(PrimitiveWithInfer):
-    """primitive for computing Concat's gradient."""
-
-    @prim_attr_register
-    def __init__(self, N=2, axis=0):
-        """init ConcatOffset"""
-
-    def __infer__(self, input_x):
-        axis = self.axis
-        x_shp = input_x['shape']
-        x_type = input_x['dtype']
-        offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
-        self.add_prim_attr('T', x_type[0].element_type())
-        offset_values = []
-        for i in range(len(x_shp)):
-            values = []
-            for j in range(len(x_shp[0])):
-                value = 0
-                if j == axis:
-                    value = offset[i]
-                values.append(value)
-            offset_values.append(tuple(values))
-        out = {'shape': None,
-               'dtype': None,
-               'value': tuple(offset_values)}
-        return out
-
-
 class Select(PrimitiveWithInfer):
     r"""
 
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index b6182f0476b..ff66e809725 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -271,3 +271,6 @@ class MakeRefKey(Primitive):
     @prim_attr_register
     def __init__(self, tag):
         validator.check_type('tag', tag, (str,))
+
+    def __call__(self):
+        pass
diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py
index 1757567db5e..b7bf1bebf5e 100644
--- a/tests/ut/python/ir/test_tensor.py
+++ b/tests/ut/python/ir/test_tensor.py
@@ -24,6 +24,7 @@ import pytest
 import mindspore as ms
 import mindspore.common.api as me
 import mindspore.nn as nn
+from mindspore import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from ..ut_filter import non_graph_engine
@@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool():
     input = ms.Tensor(input)
     input_me = ms.Tensor(input, dtype=ms.bool_)
+
+def test_tensor_operation():
+    x = Tensor(np.ones((3,3)) * 4)
+    res = x + 1
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
+    res = 1 + x
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
+    res = x - 2
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
+    res = 6 - x
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
+    res = x * 3
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
+    res = 3 * x
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
+    res = x / 2
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
+    res = 8 / x
+    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
+    with pytest.raises(TypeError):
+        res = x * (2, 3)
diff --git a/tests/vm_impl/array_ops_vm_impl.py b/tests/vm_impl/array_ops_vm_impl.py
index 4258dadc62f..38c613012e8 100644
--- a/tests/vm_impl/array_ops_vm_impl.py
+++ b/tests/vm_impl/array_ops_vm_impl.py
@@ -190,7 +190,7 @@ def vm_impl_slice(self):
     return vm_impl
 
 
-@vm_impl_getters.register(P.ConcatOffset)
+@vm_impl_getters.register(P._grad_ops.ConcatOffset)
 def vm_impl_concatOffset(self):
     """Generate vm_impl function for ConcatOffset"""
     def vm_impl(x):