forked from mindspore-Ecosystem/mindspore
resolve pynative operator issue
This commit is contained in:
parent
5ed799d7b2
commit
75fec82b52
|
@ -125,7 +125,7 @@ def list_len(x):
|
|||
return len(x)
|
||||
|
||||
|
||||
# only used in PyNative modes
|
||||
# only used in PyNative mode
|
||||
def partial(*args):
|
||||
"""Implement `partial`."""
|
||||
func = args[0].__call__
|
||||
|
@ -133,10 +133,14 @@ def partial(*args):
|
|||
return partial_func
|
||||
|
||||
|
||||
# only used in PyNative modes
|
||||
# only used in PyNative mode
|
||||
def depend(value, expr):
|
||||
return value
|
||||
|
||||
# only used in PyNative mode
|
||||
def make_ref(key, value, ref):
|
||||
return value
|
||||
|
||||
|
||||
def scalar_cast(x, t):
|
||||
"""Implement scalar_cast."""
|
||||
|
|
|
@ -616,17 +616,19 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
|
|||
return ExecDFGraph(info_, args, phase_s);
|
||||
}
|
||||
#else
|
||||
if (backend == "ge") {
|
||||
std::shared_ptr<py::object> ret_val = std::make_shared<py::object>();
|
||||
if (backend == "ms" || backend == "ge") {
|
||||
auto ret_val = std::make_shared<py::object>();
|
||||
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
|
||||
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
|
||||
return *ret_val;
|
||||
}
|
||||
}
|
||||
if (args.size() > 0) {
|
||||
return args[0];
|
||||
if (backend == "ge") {
|
||||
if (args.size() > 0) {
|
||||
return args[0];
|
||||
}
|
||||
return args;
|
||||
}
|
||||
return args;
|
||||
}
|
||||
#endif
|
||||
std::size_t full_arg_size = ArgListSize(phase_s);
|
||||
|
|
|
@ -20,11 +20,13 @@
|
|||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <algorithm>
|
||||
|
||||
#include "utils/any.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "operator/ops.h"
|
||||
#include "operator/composite/do_signature.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "pipeline/static_analysis/prim.h"
|
||||
#include "session/session_factory.h"
|
||||
|
@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) {
|
|||
return converted_ret;
|
||||
}
|
||||
|
||||
py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) {
|
||||
auto signature = prim->signatures();
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||
[](const Signature& sig) { return sig.dtype; });
|
||||
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
||||
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
||||
return py_args;
|
||||
}
|
||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||
auto it = type_indexs.find(dtypes[i]);
|
||||
if (it == type_indexs.end()) {
|
||||
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
||||
} else {
|
||||
it->second.push_back(i);
|
||||
}
|
||||
}
|
||||
std::map<SignatureEnumDType, size_t> dst_type;
|
||||
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
|
||||
auto type = it->first;
|
||||
auto indexs = it->second;
|
||||
if (indexs.size() < 2) {
|
||||
continue;
|
||||
}
|
||||
size_t m_index = indexs[0];
|
||||
for (size_t i = 1; i < indexs.size(); ++i) {
|
||||
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
|
||||
m_index = indexs[i];
|
||||
}
|
||||
}
|
||||
(void)dst_type.insert(std::make_pair(type, m_index));
|
||||
}
|
||||
py::tuple py_inputs(py_args.size());
|
||||
for (size_t i = 0; i < py_args.size(); ++i) {
|
||||
auto it = dst_type.find(dtypes[i]);
|
||||
if (it != dst_type.end() && it->second != i &&
|
||||
(py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
|
||||
auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
|
||||
if (py::isinstance<py::int_>(py_args[i])) {
|
||||
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
|
||||
} else {
|
||||
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
py_inputs[i] = py_args[i];
|
||||
}
|
||||
return py_inputs;
|
||||
}
|
||||
|
||||
void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) {
|
||||
size_t size = py_args.size();
|
||||
AbstractBasePtrList args_spec_list;
|
||||
|
@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
|
|||
auto op_exec_info = std::make_shared<OpExecInfo>();
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
|
||||
if (py::isinstance<py::none>(args[PY_PRIM])) {
|
||||
py::module ops_mod = py::module::import("mindspore.ops.operations");
|
||||
py::object py_primitive = ops_mod.attr(op_exec_info->op_name.c_str())();
|
||||
op_exec_info->py_primitive = py::cast<PrimitivePyPtr>(py_primitive);
|
||||
py::dict none_attrs = py::dict();
|
||||
op_exec_info->op_attrs = none_attrs;
|
||||
} else {
|
||||
PrimitivePyPtr prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
||||
auto pyobj = prim->GetPyObj();
|
||||
if (pyobj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "pyobj is empty";
|
||||
}
|
||||
py::tuple py_args = args[PY_INPUTS];
|
||||
// use python infer method
|
||||
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
||||
PynativeInfer(prim, py_args, op_exec_info.get());
|
||||
}
|
||||
op_exec_info->py_primitive = prim;
|
||||
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
||||
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
||||
auto pyobj = prim->GetPyObj();
|
||||
if (pyobj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "pyobj is empty";
|
||||
}
|
||||
op_exec_info->op_inputs = args[PY_INPUTS];
|
||||
py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
|
||||
// use python infer method
|
||||
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
||||
PynativeInfer(prim, py_args, op_exec_info.get());
|
||||
}
|
||||
op_exec_info->py_primitive = prim;
|
||||
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
||||
op_exec_info->op_inputs = py_args;
|
||||
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
||||
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
||||
MS_LOG(ERROR) << "" << op_exec_info->op_name << " op_inputs size not equal op_mask";
|
||||
MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
||||
return nullptr;
|
||||
}
|
||||
return op_exec_info;
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Parameter for cell."""
|
||||
from copy import copy
|
||||
from copy import copy, deepcopy
|
||||
import numpy as np
|
||||
from .initializer import initializer
|
||||
from .tensor import Tensor
|
||||
|
@ -156,16 +156,24 @@ class Parameter:
|
|||
return self.default_input
|
||||
|
||||
def __add__(self, other):
|
||||
return self.default_input + other
|
||||
res = deepcopy(self)
|
||||
res.default_input = res.default_input + other
|
||||
return res
|
||||
|
||||
def __sub__(self, other):
|
||||
return self.default_input - other
|
||||
res = deepcopy(self)
|
||||
res.default_input = res.default_input - other
|
||||
return res
|
||||
|
||||
def __mul__(self, other):
|
||||
return self.default_input * other
|
||||
res = deepcopy(self)
|
||||
res.default_input = res.default_input * other
|
||||
return res
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.default_input / other
|
||||
res = deepcopy(self)
|
||||
res.default_input = res.default_input / other
|
||||
return res
|
||||
|
||||
def set_parameter_data(self, data):
|
||||
if isinstance(data, (Tensor, list, int, float,
|
||||
|
|
|
@ -70,45 +70,60 @@ class Tensor(Tensor_):
|
|||
return str(self.__str__())
|
||||
|
||||
def __add__(self, other):
|
||||
if not isinstance(other, Tensor):
|
||||
raise TypeError("input_data must be a tensor")
|
||||
check_type('tensor input_data', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__add__')(self, other)
|
||||
return out
|
||||
|
||||
def __mul__(self, other):
|
||||
if not isinstance(other, Tensor):
|
||||
raise TypeError("input_data must be a tensor")
|
||||
check_type('tensor input_data', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__mul__')(self, other)
|
||||
return out
|
||||
|
||||
def __neg__(self):
|
||||
return Tensor(-self.asnumpy())
|
||||
|
||||
def __iadd__(self, other):
|
||||
out = self.__add__(other)
|
||||
return out
|
||||
|
||||
def __radd__(self, other):
|
||||
check_type('tensor operation input', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__add__')(other, self)
|
||||
return out
|
||||
|
||||
def __imul__(self, other):
|
||||
out = self.__mul__(other)
|
||||
return out
|
||||
|
||||
def __rmul__(self, other):
|
||||
check_type('tensor operation input', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__mul__')(other, self)
|
||||
return out
|
||||
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, (int, float)):
|
||||
other_tensor = Tensor(other, self.dtype())
|
||||
elif isinstance(other, Tensor):
|
||||
other_tensor = other
|
||||
else:
|
||||
raise TypeError("unsupported type for div operation")
|
||||
out = tensor_operator_registry.get('__div__')(self, other_tensor)
|
||||
check_type('tensor operation input', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__div__')(self, other)
|
||||
return out
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
check_type('tensor operation input', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__div__')(other, self)
|
||||
return out
|
||||
|
||||
def __sub__(self, other):
|
||||
if not isinstance(other, Tensor):
|
||||
raise TypeError("input_data must be a tensor")
|
||||
out = self.__add__(Tensor(-other.asnumpy()))
|
||||
check_type('tensor operation input', other, (Tensor, float, int))
|
||||
out = self.__add__(-other)
|
||||
return out
|
||||
|
||||
def __isub__(self, other):
|
||||
out = self.__sub__(other)
|
||||
return out
|
||||
|
||||
def __rsub__(self, other):
|
||||
check_type('tensor operation input', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
|
||||
return out
|
||||
|
||||
def __str__(self):
|
||||
if self.dtype() == mstype.type_none:
|
||||
return "Unknown Tensor type!"
|
||||
|
|
|
@ -191,7 +191,7 @@ def get_bprop_concat(self):
|
|||
|
||||
def bprop(x, out, dout):
|
||||
dx = ()
|
||||
out_offset = P.ConcatOffset(F.tuple_len(x), axis)(x)
|
||||
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
|
||||
for i in range(F.tuple_len(x)):
|
||||
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
|
||||
dx = dx + (slice_out,)
|
||||
|
|
|
@ -14,6 +14,6 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ops utils."""
|
||||
from .broadcast import _get_broadcast_shape
|
||||
from .utils import _get_broadcast_shape, _get_concat_offset
|
||||
|
||||
__all__ = ['_get_broadcast_shape']
|
||||
__all__ = ['_get_broadcast_shape', '_get_concat_offset']
|
||||
|
|
|
@ -13,8 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""broadcast"""
|
||||
"""utils for operator"""
|
||||
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
|
||||
def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||
"""
|
||||
|
@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
|||
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
||||
broadcast_shape = broadcast_shape_front + broadcast_shape_back
|
||||
return broadcast_shape
|
||||
|
||||
|
||||
def _get_concat_offset(x_shp, x_type, axis):
|
||||
"""for concat and concatoffset check args and compute offset"""
|
||||
validator.check_type("shape", x_shp, [tuple])
|
||||
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
|
||||
validator.check_subclass("shape0", x_type[0], mstype.tensor)
|
||||
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
|
||||
rank_base = len(x_shp[0])
|
||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
|
||||
if axis < 0:
|
||||
axis = axis + rank_base
|
||||
all_shp = x_shp[0][axis]
|
||||
offset = [0,]
|
||||
for i in range(1, len(x_shp)):
|
||||
v = x_shp[i]
|
||||
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
|
||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
|
||||
for j in range(rank_base):
|
||||
if j != axis and v[j] != x_shp[0][j]:
|
||||
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
|
||||
offset.append(all_shp)
|
||||
all_shp += v[axis]
|
||||
return offset, all_shp, axis
|
|
@ -19,7 +19,7 @@ Primitive operator classes.
|
|||
A collection of operators to build nerual networks or computing functions.
|
||||
"""
|
||||
|
||||
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Pack, Unpack,
|
||||
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||
Fill, GatherNd, GatherV2, InvertPermutation,
|
||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||
|
@ -200,7 +200,6 @@ __all__ = [
|
|||
'LogicalOr',
|
||||
'Size',
|
||||
'DepthwiseConv2dNative',
|
||||
'ConcatOffset',
|
||||
'UnsortedSegmentSum',
|
||||
"AllGather",
|
||||
"AllReduce",
|
||||
|
|
|
@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind
|
|||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Rel, check_int_positive, check_bool
|
||||
from .._utils import _get_concat_offset
|
||||
from ...common import dtype as mstype
|
||||
|
||||
|
||||
|
@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
|
|||
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
|
||||
return x_type
|
||||
|
||||
class ConcatOffset(PrimitiveWithInfer):
|
||||
"""primitive for computing Concat's gradient."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, N=2, axis=0):
|
||||
"""init ConcatOffset"""
|
||||
|
||||
def __infer__(self, input_x):
|
||||
axis = self.axis
|
||||
x_shp = input_x['shape']
|
||||
x_type = input_x['dtype']
|
||||
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
|
||||
self.add_prim_attr('T', x_type[0].element_type())
|
||||
offset_values = []
|
||||
for i in range(len(x_shp)):
|
||||
values = []
|
||||
for j in range(len(x_shp[0])):
|
||||
value = 0
|
||||
if j == axis:
|
||||
value = offset[i]
|
||||
values.append(value)
|
||||
offset_values.append(tuple(values))
|
||||
out = {'shape': None,
|
||||
'dtype': None,
|
||||
'value': tuple(offset_values)}
|
||||
return out
|
||||
|
||||
|
||||
class Conv2DBackpropFilter(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -29,6 +29,7 @@ from ..._checkparam import Rel
|
|||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from .._utils import _get_concat_offset
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
def _check_infer_attr_reduce(axis, keep_dims):
|
||||
|
@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
def _get_concat_offset(x_shp, x_type, axis):
|
||||
"""for concat and concatoffset check args and compute offset"""
|
||||
validator.check_type("shape", x_shp, [tuple])
|
||||
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
|
||||
validator.check_subclass("shape0", x_type[0], mstype.tensor)
|
||||
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
|
||||
rank_base = len(x_shp[0])
|
||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
|
||||
if axis < 0:
|
||||
axis = axis + rank_base
|
||||
all_shp = x_shp[0][axis]
|
||||
offset = [0,]
|
||||
for i in range(1, len(x_shp)):
|
||||
v = x_shp[i]
|
||||
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
|
||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
|
||||
for j in range(rank_base):
|
||||
if j != axis and v[j] != x_shp[0][j]:
|
||||
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
|
||||
offset.append(all_shp)
|
||||
all_shp += v[axis]
|
||||
return offset, all_shp, axis
|
||||
|
||||
|
||||
class Concat(PrimitiveWithInfer):
|
||||
r"""
|
||||
Concat tensor in specified axis.
|
||||
|
@ -1531,34 +1508,6 @@ class Slice(PrimitiveWithInfer):
|
|||
'value': None}
|
||||
|
||||
|
||||
class ConcatOffset(PrimitiveWithInfer):
|
||||
"""primitive for computing Concat's gradient."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, N=2, axis=0):
|
||||
"""init ConcatOffset"""
|
||||
|
||||
def __infer__(self, input_x):
|
||||
axis = self.axis
|
||||
x_shp = input_x['shape']
|
||||
x_type = input_x['dtype']
|
||||
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
|
||||
self.add_prim_attr('T', x_type[0].element_type())
|
||||
offset_values = []
|
||||
for i in range(len(x_shp)):
|
||||
values = []
|
||||
for j in range(len(x_shp[0])):
|
||||
value = 0
|
||||
if j == axis:
|
||||
value = offset[i]
|
||||
values.append(value)
|
||||
offset_values.append(tuple(values))
|
||||
out = {'shape': None,
|
||||
'dtype': None,
|
||||
'value': tuple(offset_values)}
|
||||
return out
|
||||
|
||||
|
||||
class Select(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
||||
|
|
|
@ -271,3 +271,6 @@ class MakeRefKey(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self, tag):
|
||||
validator.check_type('tag', tag, (str,))
|
||||
|
||||
def __call__(self):
|
||||
pass
|
||||
|
|
|
@ -24,6 +24,7 @@ import pytest
|
|||
import mindspore as ms
|
||||
import mindspore.common.api as me
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool():
|
|||
input = ms.Tensor(input)
|
||||
input_me = ms.Tensor(input, dtype=ms.bool_)
|
||||
|
||||
|
||||
def test_tensor_operation():
|
||||
x = Tensor(np.ones((3,3)) * 4)
|
||||
res = x + 1
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
|
||||
res = 1 + x
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
|
||||
res = x - 2
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||
res = 6 - x
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||
res = x * 3
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
|
||||
res = 3 * x
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
|
||||
res = x / 2
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||
res = 8 / x
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||
with pytest.raises(TypeError):
|
||||
res = x * (2, 3)
|
||||
|
|
|
@ -190,7 +190,7 @@ def vm_impl_slice(self):
|
|||
return vm_impl
|
||||
|
||||
|
||||
@vm_impl_getters.register(P.ConcatOffset)
|
||||
@vm_impl_getters.register(P._grad_ops.ConcatOffset)
|
||||
def vm_impl_concatOffset(self):
|
||||
"""Generate vm_impl function for ConcatOffset"""
|
||||
def vm_impl(x):
|
||||
|
|
Loading…
Reference in New Issue