forked from mindspore-Ecosystem/mindspore
!286 resolve pynative operator unsupport issue
Merge pull request !286 from wangqiuliang/resolve-pynative-operator-issue
This commit is contained in:
commit
08985a1e56
|
@ -125,7 +125,7 @@ def list_len(x):
|
||||||
return len(x)
|
return len(x)
|
||||||
|
|
||||||
|
|
||||||
# only used in PyNative modes
|
# only used in PyNative mode
|
||||||
def partial(*args):
|
def partial(*args):
|
||||||
"""Implement `partial`."""
|
"""Implement `partial`."""
|
||||||
func = args[0].__call__
|
func = args[0].__call__
|
||||||
|
@ -133,10 +133,14 @@ def partial(*args):
|
||||||
return partial_func
|
return partial_func
|
||||||
|
|
||||||
|
|
||||||
# only used in PyNative modes
|
# only used in PyNative mode
|
||||||
def depend(value, expr):
|
def depend(value, expr):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
# only used in PyNative mode
|
||||||
|
def make_ref(key, value, ref):
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def scalar_cast(x, t):
|
def scalar_cast(x, t):
|
||||||
"""Implement scalar_cast."""
|
"""Implement scalar_cast."""
|
||||||
|
|
|
@ -616,18 +616,20 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
|
||||||
return ExecDFGraph(info_, args, phase_s);
|
return ExecDFGraph(info_, args, phase_s);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
if (backend == "ge") {
|
if (backend == "ms" || backend == "ge") {
|
||||||
std::shared_ptr<py::object> ret_val = std::make_shared<py::object>();
|
auto ret_val = std::make_shared<py::object>();
|
||||||
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
|
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
|
||||||
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
|
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
|
||||||
return *ret_val;
|
return *ret_val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (backend == "ge") {
|
||||||
if (args.size() > 0) {
|
if (args.size() > 0) {
|
||||||
return args[0];
|
return args[0];
|
||||||
}
|
}
|
||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
std::size_t full_arg_size = ArgListSize(phase_s);
|
std::size_t full_arg_size = ArgListSize(phase_s);
|
||||||
if (size > full_arg_size) {
|
if (size > full_arg_size) {
|
||||||
|
|
|
@ -20,11 +20,13 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include "utils/any.h"
|
#include "utils/any.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "utils/context/ms_context.h"
|
#include "utils/context/ms_context.h"
|
||||||
#include "operator/ops.h"
|
#include "operator/ops.h"
|
||||||
|
#include "operator/composite/do_signature.h"
|
||||||
#include "pipeline/parse/data_converter.h"
|
#include "pipeline/parse/data_converter.h"
|
||||||
#include "pipeline/static_analysis/prim.h"
|
#include "pipeline/static_analysis/prim.h"
|
||||||
#include "session/session_factory.h"
|
#include "session/session_factory.h"
|
||||||
|
@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) {
|
||||||
return converted_ret;
|
return converted_ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) {
|
||||||
|
auto signature = prim->signatures();
|
||||||
|
std::vector<SignatureEnumDType> dtypes;
|
||||||
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||||
|
[](const Signature& sig) { return sig.dtype; });
|
||||||
|
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
||||||
|
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
||||||
|
return py_args;
|
||||||
|
}
|
||||||
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||||
|
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||||
|
auto it = type_indexs.find(dtypes[i]);
|
||||||
|
if (it == type_indexs.end()) {
|
||||||
|
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
||||||
|
} else {
|
||||||
|
it->second.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::map<SignatureEnumDType, size_t> dst_type;
|
||||||
|
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
|
||||||
|
auto type = it->first;
|
||||||
|
auto indexs = it->second;
|
||||||
|
if (indexs.size() < 2) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
size_t m_index = indexs[0];
|
||||||
|
for (size_t i = 1; i < indexs.size(); ++i) {
|
||||||
|
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
|
||||||
|
m_index = indexs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(void)dst_type.insert(std::make_pair(type, m_index));
|
||||||
|
}
|
||||||
|
py::tuple py_inputs(py_args.size());
|
||||||
|
for (size_t i = 0; i < py_args.size(); ++i) {
|
||||||
|
auto it = dst_type.find(dtypes[i]);
|
||||||
|
if (it != dst_type.end() && it->second != i &&
|
||||||
|
(py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
|
||||||
|
auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
|
||||||
|
if (py::isinstance<py::int_>(py_args[i])) {
|
||||||
|
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
|
||||||
|
} else {
|
||||||
|
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
py_inputs[i] = py_args[i];
|
||||||
|
}
|
||||||
|
return py_inputs;
|
||||||
|
}
|
||||||
|
|
||||||
void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) {
|
void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) {
|
||||||
size_t size = py_args.size();
|
size_t size = py_args.size();
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
|
@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
|
||||||
auto op_exec_info = std::make_shared<OpExecInfo>();
|
auto op_exec_info = std::make_shared<OpExecInfo>();
|
||||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||||
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
|
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
|
||||||
if (py::isinstance<py::none>(args[PY_PRIM])) {
|
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
||||||
py::module ops_mod = py::module::import("mindspore.ops.operations");
|
|
||||||
py::object py_primitive = ops_mod.attr(op_exec_info->op_name.c_str())();
|
|
||||||
op_exec_info->py_primitive = py::cast<PrimitivePyPtr>(py_primitive);
|
|
||||||
py::dict none_attrs = py::dict();
|
|
||||||
op_exec_info->op_attrs = none_attrs;
|
|
||||||
} else {
|
|
||||||
PrimitivePyPtr prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
|
||||||
auto pyobj = prim->GetPyObj();
|
auto pyobj = prim->GetPyObj();
|
||||||
if (pyobj == nullptr) {
|
if (pyobj == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "pyobj is empty";
|
MS_LOG(EXCEPTION) << "pyobj is empty";
|
||||||
}
|
}
|
||||||
py::tuple py_args = args[PY_INPUTS];
|
py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
|
||||||
// use python infer method
|
// use python infer method
|
||||||
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
||||||
PynativeInfer(prim, py_args, op_exec_info.get());
|
PynativeInfer(prim, py_args, op_exec_info.get());
|
||||||
}
|
}
|
||||||
op_exec_info->py_primitive = prim;
|
op_exec_info->py_primitive = prim;
|
||||||
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
||||||
}
|
op_exec_info->op_inputs = py_args;
|
||||||
op_exec_info->op_inputs = args[PY_INPUTS];
|
|
||||||
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
||||||
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
||||||
MS_LOG(ERROR) << "" << op_exec_info->op_name << " op_inputs size not equal op_mask";
|
MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return op_exec_info;
|
return op_exec_info;
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""Parameter for cell."""
|
"""Parameter for cell."""
|
||||||
from copy import copy
|
from copy import copy, deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .initializer import initializer
|
from .initializer import initializer
|
||||||
from .tensor import Tensor
|
from .tensor import Tensor
|
||||||
|
@ -156,16 +156,24 @@ class Parameter:
|
||||||
return self.default_input
|
return self.default_input
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
return self.default_input + other
|
res = deepcopy(self)
|
||||||
|
res.default_input = res.default_input + other
|
||||||
|
return res
|
||||||
|
|
||||||
def __sub__(self, other):
|
def __sub__(self, other):
|
||||||
return self.default_input - other
|
res = deepcopy(self)
|
||||||
|
res.default_input = res.default_input - other
|
||||||
|
return res
|
||||||
|
|
||||||
def __mul__(self, other):
|
def __mul__(self, other):
|
||||||
return self.default_input * other
|
res = deepcopy(self)
|
||||||
|
res.default_input = res.default_input * other
|
||||||
|
return res
|
||||||
|
|
||||||
def __truediv__(self, other):
|
def __truediv__(self, other):
|
||||||
return self.default_input / other
|
res = deepcopy(self)
|
||||||
|
res.default_input = res.default_input / other
|
||||||
|
return res
|
||||||
|
|
||||||
def set_parameter_data(self, data):
|
def set_parameter_data(self, data):
|
||||||
if isinstance(data, (Tensor, list, int, float,
|
if isinstance(data, (Tensor, list, int, float,
|
||||||
|
|
|
@ -70,45 +70,60 @@ class Tensor(Tensor_):
|
||||||
return str(self.__str__())
|
return str(self.__str__())
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
if not isinstance(other, Tensor):
|
check_type('tensor input_data', other, (Tensor, float, int))
|
||||||
raise TypeError("input_data must be a tensor")
|
|
||||||
out = tensor_operator_registry.get('__add__')(self, other)
|
out = tensor_operator_registry.get('__add__')(self, other)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __mul__(self, other):
|
def __mul__(self, other):
|
||||||
if not isinstance(other, Tensor):
|
check_type('tensor input_data', other, (Tensor, float, int))
|
||||||
raise TypeError("input_data must be a tensor")
|
|
||||||
out = tensor_operator_registry.get('__mul__')(self, other)
|
out = tensor_operator_registry.get('__mul__')(self, other)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def __neg__(self):
|
||||||
|
return Tensor(-self.asnumpy())
|
||||||
|
|
||||||
def __iadd__(self, other):
|
def __iadd__(self, other):
|
||||||
out = self.__add__(other)
|
out = self.__add__(other)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def __radd__(self, other):
|
||||||
|
check_type('tensor operation input', other, (Tensor, float, int))
|
||||||
|
out = tensor_operator_registry.get('__add__')(other, self)
|
||||||
|
return out
|
||||||
|
|
||||||
def __imul__(self, other):
|
def __imul__(self, other):
|
||||||
out = self.__mul__(other)
|
out = self.__mul__(other)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def __rmul__(self, other):
|
||||||
|
check_type('tensor operation input', other, (Tensor, float, int))
|
||||||
|
out = tensor_operator_registry.get('__mul__')(other, self)
|
||||||
|
return out
|
||||||
|
|
||||||
def __truediv__(self, other):
|
def __truediv__(self, other):
|
||||||
if isinstance(other, (int, float)):
|
check_type('tensor operation input', other, (Tensor, float, int))
|
||||||
other_tensor = Tensor(other, self.dtype())
|
out = tensor_operator_registry.get('__div__')(self, other)
|
||||||
elif isinstance(other, Tensor):
|
return out
|
||||||
other_tensor = other
|
|
||||||
else:
|
def __rtruediv__(self, other):
|
||||||
raise TypeError("unsupported type for div operation")
|
check_type('tensor operation input', other, (Tensor, float, int))
|
||||||
out = tensor_operator_registry.get('__div__')(self, other_tensor)
|
out = tensor_operator_registry.get('__div__')(other, self)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __sub__(self, other):
|
def __sub__(self, other):
|
||||||
if not isinstance(other, Tensor):
|
check_type('tensor operation input', other, (Tensor, float, int))
|
||||||
raise TypeError("input_data must be a tensor")
|
out = self.__add__(-other)
|
||||||
out = self.__add__(Tensor(-other.asnumpy()))
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __isub__(self, other):
|
def __isub__(self, other):
|
||||||
out = self.__sub__(other)
|
out = self.__sub__(other)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def __rsub__(self, other):
|
||||||
|
check_type('tensor operation input', other, (Tensor, float, int))
|
||||||
|
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
|
||||||
|
return out
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.dtype() == mstype.type_none:
|
if self.dtype() == mstype.type_none:
|
||||||
return "Unknown Tensor type!"
|
return "Unknown Tensor type!"
|
||||||
|
|
|
@ -191,7 +191,7 @@ def get_bprop_concat(self):
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
dx = ()
|
dx = ()
|
||||||
out_offset = P.ConcatOffset(F.tuple_len(x), axis)(x)
|
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
|
||||||
for i in range(F.tuple_len(x)):
|
for i in range(F.tuple_len(x)):
|
||||||
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
|
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
|
||||||
dx = dx + (slice_out,)
|
dx = dx + (slice_out,)
|
||||||
|
|
|
@ -14,6 +14,6 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""ops utils."""
|
"""ops utils."""
|
||||||
from .broadcast import _get_broadcast_shape
|
from .utils import _get_broadcast_shape, _get_concat_offset
|
||||||
|
|
||||||
__all__ = ['_get_broadcast_shape']
|
__all__ = ['_get_broadcast_shape', '_get_concat_offset']
|
||||||
|
|
|
@ -13,8 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""broadcast"""
|
"""utils for operator"""
|
||||||
|
|
||||||
|
from ..._checkparam import ParamValidator as validator
|
||||||
|
from ..._checkparam import Rel
|
||||||
|
from ...common import dtype as mstype
|
||||||
|
|
||||||
def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||||
"""
|
"""
|
||||||
|
@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||||
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
||||||
broadcast_shape = broadcast_shape_front + broadcast_shape_back
|
broadcast_shape = broadcast_shape_front + broadcast_shape_back
|
||||||
return broadcast_shape
|
return broadcast_shape
|
||||||
|
|
||||||
|
|
||||||
|
def _get_concat_offset(x_shp, x_type, axis):
|
||||||
|
"""for concat and concatoffset check args and compute offset"""
|
||||||
|
validator.check_type("shape", x_shp, [tuple])
|
||||||
|
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
|
||||||
|
validator.check_subclass("shape0", x_type[0], mstype.tensor)
|
||||||
|
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
|
||||||
|
rank_base = len(x_shp[0])
|
||||||
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
|
||||||
|
if axis < 0:
|
||||||
|
axis = axis + rank_base
|
||||||
|
all_shp = x_shp[0][axis]
|
||||||
|
offset = [0,]
|
||||||
|
for i in range(1, len(x_shp)):
|
||||||
|
v = x_shp[i]
|
||||||
|
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
|
||||||
|
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
|
||||||
|
for j in range(rank_base):
|
||||||
|
if j != axis and v[j] != x_shp[0][j]:
|
||||||
|
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
|
||||||
|
offset.append(all_shp)
|
||||||
|
all_shp += v[axis]
|
||||||
|
return offset, all_shp, axis
|
|
@ -19,7 +19,7 @@ Primitive operator classes.
|
||||||
A collection of operators to build nerual networks or computing functions.
|
A collection of operators to build nerual networks or computing functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Pack, Unpack,
|
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Diag, DiagPart, DType, ExpandDims, Eye,
|
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||||
Fill, GatherNd, GatherV2, InvertPermutation,
|
Fill, GatherNd, GatherV2, InvertPermutation,
|
||||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||||
|
@ -200,7 +200,6 @@ __all__ = [
|
||||||
'LogicalOr',
|
'LogicalOr',
|
||||||
'Size',
|
'Size',
|
||||||
'DepthwiseConv2dNative',
|
'DepthwiseConv2dNative',
|
||||||
'ConcatOffset',
|
|
||||||
'UnsortedSegmentSum',
|
'UnsortedSegmentSum',
|
||||||
"AllGather",
|
"AllGather",
|
||||||
"AllReduce",
|
"AllReduce",
|
||||||
|
|
|
@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||||
from ..._checkparam import ParamValidator as validator
|
from ..._checkparam import ParamValidator as validator
|
||||||
from ..._checkparam import Rel, check_int_positive, check_bool
|
from ..._checkparam import Rel, check_int_positive, check_bool
|
||||||
|
from .._utils import _get_concat_offset
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
|
||||||
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
|
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
|
||||||
return x_type
|
return x_type
|
||||||
|
|
||||||
|
class ConcatOffset(PrimitiveWithInfer):
|
||||||
|
"""primitive for computing Concat's gradient."""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, N=2, axis=0):
|
||||||
|
"""init ConcatOffset"""
|
||||||
|
|
||||||
|
def __infer__(self, input_x):
|
||||||
|
axis = self.axis
|
||||||
|
x_shp = input_x['shape']
|
||||||
|
x_type = input_x['dtype']
|
||||||
|
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
|
||||||
|
self.add_prim_attr('T', x_type[0].element_type())
|
||||||
|
offset_values = []
|
||||||
|
for i in range(len(x_shp)):
|
||||||
|
values = []
|
||||||
|
for j in range(len(x_shp[0])):
|
||||||
|
value = 0
|
||||||
|
if j == axis:
|
||||||
|
value = offset[i]
|
||||||
|
values.append(value)
|
||||||
|
offset_values.append(tuple(values))
|
||||||
|
out = {'shape': None,
|
||||||
|
'dtype': None,
|
||||||
|
'value': tuple(offset_values)}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Conv2DBackpropFilter(PrimitiveWithInfer):
|
class Conv2DBackpropFilter(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -29,6 +29,7 @@ from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...common.tensor import Tensor
|
from ...common.tensor import Tensor
|
||||||
from ..operations.math_ops import _infer_shape_reduce
|
from ..operations.math_ops import _infer_shape_reduce
|
||||||
|
from .._utils import _get_concat_offset
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||||
|
|
||||||
def _check_infer_attr_reduce(axis, keep_dims):
|
def _check_infer_attr_reduce(axis, keep_dims):
|
||||||
|
@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _get_concat_offset(x_shp, x_type, axis):
|
|
||||||
"""for concat and concatoffset check args and compute offset"""
|
|
||||||
validator.check_type("shape", x_shp, [tuple])
|
|
||||||
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
|
|
||||||
validator.check_subclass("shape0", x_type[0], mstype.tensor)
|
|
||||||
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
|
|
||||||
rank_base = len(x_shp[0])
|
|
||||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
|
|
||||||
if axis < 0:
|
|
||||||
axis = axis + rank_base
|
|
||||||
all_shp = x_shp[0][axis]
|
|
||||||
offset = [0,]
|
|
||||||
for i in range(1, len(x_shp)):
|
|
||||||
v = x_shp[i]
|
|
||||||
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
|
|
||||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
|
|
||||||
for j in range(rank_base):
|
|
||||||
if j != axis and v[j] != x_shp[0][j]:
|
|
||||||
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
|
|
||||||
offset.append(all_shp)
|
|
||||||
all_shp += v[axis]
|
|
||||||
return offset, all_shp, axis
|
|
||||||
|
|
||||||
|
|
||||||
class Concat(PrimitiveWithInfer):
|
class Concat(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Concat tensor in specified axis.
|
Concat tensor in specified axis.
|
||||||
|
@ -1533,34 +1510,6 @@ class Slice(PrimitiveWithInfer):
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
||||||
|
|
||||||
class ConcatOffset(PrimitiveWithInfer):
|
|
||||||
"""primitive for computing Concat's gradient."""
|
|
||||||
|
|
||||||
@prim_attr_register
|
|
||||||
def __init__(self, N=2, axis=0):
|
|
||||||
"""init ConcatOffset"""
|
|
||||||
|
|
||||||
def __infer__(self, input_x):
|
|
||||||
axis = self.axis
|
|
||||||
x_shp = input_x['shape']
|
|
||||||
x_type = input_x['dtype']
|
|
||||||
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
|
|
||||||
self.add_prim_attr('T', x_type[0].element_type())
|
|
||||||
offset_values = []
|
|
||||||
for i in range(len(x_shp)):
|
|
||||||
values = []
|
|
||||||
for j in range(len(x_shp[0])):
|
|
||||||
value = 0
|
|
||||||
if j == axis:
|
|
||||||
value = offset[i]
|
|
||||||
values.append(value)
|
|
||||||
offset_values.append(tuple(values))
|
|
||||||
out = {'shape': None,
|
|
||||||
'dtype': None,
|
|
||||||
'value': tuple(offset_values)}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Select(PrimitiveWithInfer):
|
class Select(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
|
|
||||||
|
|
|
@ -271,3 +271,6 @@ class MakeRefKey(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, tag):
|
def __init__(self, tag):
|
||||||
validator.check_type('tag', tag, (str,))
|
validator.check_type('tag', tag, (str,))
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
pass
|
||||||
|
|
|
@ -24,6 +24,7 @@ import pytest
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.common.api as me
|
import mindspore.common.api as me
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from ..ut_filter import non_graph_engine
|
from ..ut_filter import non_graph_engine
|
||||||
|
@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool():
|
||||||
input = ms.Tensor(input)
|
input = ms.Tensor(input)
|
||||||
input_me = ms.Tensor(input, dtype=ms.bool_)
|
input_me = ms.Tensor(input, dtype=ms.bool_)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_operation():
|
||||||
|
x = Tensor(np.ones((3,3)) * 4)
|
||||||
|
res = x + 1
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
|
||||||
|
res = 1 + x
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
|
||||||
|
res = x - 2
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||||
|
res = 6 - x
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||||
|
res = x * 3
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
|
||||||
|
res = 3 * x
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
|
||||||
|
res = x / 2
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||||
|
res = 8 / x
|
||||||
|
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
res = x * (2, 3)
|
||||||
|
|
|
@ -190,7 +190,7 @@ def vm_impl_slice(self):
|
||||||
return vm_impl
|
return vm_impl
|
||||||
|
|
||||||
|
|
||||||
@vm_impl_getters.register(P.ConcatOffset)
|
@vm_impl_getters.register(P._grad_ops.ConcatOffset)
|
||||||
def vm_impl_concatOffset(self):
|
def vm_impl_concatOffset(self):
|
||||||
"""Generate vm_impl function for ConcatOffset"""
|
"""Generate vm_impl function for ConcatOffset"""
|
||||||
def vm_impl(x):
|
def vm_impl(x):
|
||||||
|
|
Loading…
Reference in New Issue