forked from mindspore-Ecosystem/mindspore
!5645 [bug][api]updata signature
Merge pull request !5645 from vlne-v1/ref_demo
This commit is contained in:
commit
021ba724cf
|
@ -20,7 +20,6 @@ from mindspore.common.tensor import Tensor
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
||||
|
||||
|
||||
def scalar_add(x, y):
|
||||
"""Implement `scalar_add`."""
|
||||
return x + y
|
||||
|
@ -117,25 +116,6 @@ def bool_or(x, y):
|
|||
return x or y
|
||||
|
||||
|
||||
def vm_compare(*args):
|
||||
"""Implement `vm_compare` for tensor."""
|
||||
obj_str = args[-1]
|
||||
if obj_str == "shape":
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return fn
|
||||
if len(args) == 2:
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return Tensor(fn())
|
||||
if isinstance(args[0], Tensor):
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
|
||||
else:
|
||||
obj_str = "__r" + obj_str[2:]
|
||||
fn = getattr(args[1].asnumpy(), obj_str)
|
||||
y = args[0]
|
||||
return Tensor(np.array(fn(y)))
|
||||
|
||||
|
||||
def make_list(*xs):
|
||||
"""Implement `make_list`."""
|
||||
return list(xs)
|
||||
|
|
|
@ -262,6 +262,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
std::set<size_t> write_indices;
|
||||
std::vector<TypePtr> input_types;
|
||||
op_inputs.push_back(NewValueNode(function));
|
||||
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
|
||||
// Assume, the write input of op is always the first input. We check if any write op,
|
||||
// and add cast op on other inputs to keep the same type with assigned parameter.
|
||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||
|
@ -280,7 +281,6 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
|
||||
TypePtr type = args_spec_list[i]->BuildType();
|
||||
if (type && type->isa<RefType>()) {
|
||||
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
|
||||
if (sig == SignatureEnumRW::kRWRead) {
|
||||
auto source_tensor_type = type->cast<TensorTypePtr>();
|
||||
if (source_tensor_type != nullptr) {
|
||||
|
@ -300,8 +300,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
|
||||
<< type->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
|
||||
<< args_spec_list[i]->ToString();
|
||||
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs "
|
||||
<< args_spec_list[i]->ToString() << " type " << type->ToString();
|
||||
input_types.push_back(type);
|
||||
op_inputs.push_back(param);
|
||||
}
|
||||
|
|
|
@ -305,9 +305,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
dic[ATTR_SHAPE] = shape;
|
||||
dic[ATTR_DTYPE] = arg_slice->BuildType();
|
||||
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
|
||||
} else if (abs_base->isa<AbstractRef>()) {
|
||||
auto value = abs_base->cast<AbstractRefPtr>()->ref();
|
||||
dic = ConvertAbstractToPython(value);
|
||||
} else if (abs_base->isa<AbstractEllipsis>()) {
|
||||
dic[ATTR_SHAPE] = py::none();
|
||||
dic[ATTR_DTYPE] = py::ellipsis();
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace mindspore {
|
|||
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
|
||||
// Define python "MetaFuncGraph_" class
|
||||
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
|
||||
.def(py::init<std::string &>());
|
||||
.def("set_signatures", &MetaFuncGraph::set_signatures, "Set primitive inputs signature.");
|
||||
// Define python "FuncGraph" class
|
||||
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
|
||||
.def(py::init())
|
||||
|
|
|
@ -48,22 +48,9 @@ void SyncData(const py::object &arg) {
|
|||
}
|
||||
} // namespace
|
||||
std::map<std::string, py::object> PrimitivePy::hook_grad_;
|
||||
static ValuePtr PyArgToValue(const py::object &arg) {
|
||||
if (py::isinstance<SignatureEnumKind>(arg) &&
|
||||
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
|
||||
return nullptr;
|
||||
}
|
||||
return parse::data_converter::PyDataToValue(arg);
|
||||
}
|
||||
|
||||
void PrimitivePy::set_signatures(
|
||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
|
||||
signatures_.clear();
|
||||
for (auto &signature : signatures) {
|
||||
auto [name, rw, kind, arg_default, dtype] = signature;
|
||||
auto default_value = PyArgToValue(arg_default);
|
||||
signatures_.emplace_back(name, rw, kind, default_value, dtype);
|
||||
}
|
||||
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
|
||||
signatures_ = signatures;
|
||||
set_has_signature(true);
|
||||
}
|
||||
|
||||
|
|
|
@ -42,9 +42,7 @@ class PrimitivePy : public Primitive {
|
|||
MS_DECLARE_PARENT(PrimitivePy, Primitive);
|
||||
py::function GetBpropFunction();
|
||||
|
||||
void set_signatures(
|
||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
|
||||
signatures);
|
||||
void set_signatures(const std::vector<Signature> &signatures);
|
||||
|
||||
const std::vector<Signature> &signatures() const { return signatures_; }
|
||||
|
||||
|
|
|
@ -17,12 +17,26 @@
|
|||
#include "ir/signature.h"
|
||||
#include "pybind11/operators.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
static ValuePtr PyArgToValue(const py::object &arg) {
|
||||
if (py::isinstance<SignatureEnumKind>(arg) &&
|
||||
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
|
||||
return nullptr;
|
||||
}
|
||||
return parse::data_converter::PyDataToValue(arg);
|
||||
}
|
||||
// Bind SignatureEnumRW as a python class.
|
||||
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
|
||||
(void)py::class_<Signature>(*m, "Signature")
|
||||
.def(py::init([](std::string name, SignatureEnumRW rw, SignatureEnumKind kind,
|
||||
py::object arg_default, SignatureEnumDType dtype) {
|
||||
auto default_value = PyArgToValue(arg_default);
|
||||
return Signature(name, rw, kind, default_value, dtype);
|
||||
}));
|
||||
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
|
||||
.value("RW_READ", SignatureEnumRW::kRWRead)
|
||||
.value("RW_WRITE", SignatureEnumRW::kRWWrite)
|
||||
|
|
|
@ -393,3 +393,24 @@ class SparseTensor:
|
|||
@property
|
||||
def dense_shape(self):
|
||||
return self.__dense_shape
|
||||
|
||||
|
||||
def _vm_compare(*args):
|
||||
"""Implement `vm_compare` for tensor."""
|
||||
obj_str = args[-1]
|
||||
if obj_str == "shape":
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return fn
|
||||
if len(args) == 2:
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return Tensor(fn())
|
||||
if isinstance(args[0], Tensor):
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
|
||||
else:
|
||||
obj_str = "__r" + obj_str[2:]
|
||||
fn = getattr(args[1].asnumpy(), obj_str)
|
||||
y = args[0]
|
||||
return Tensor(np.array(fn(y)))
|
||||
|
||||
tensor_operator_registry.register('vm_compare', _vm_compare)
|
||||
|
|
|
@ -34,14 +34,17 @@ from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
|||
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
||||
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
|
||||
from .primitive import constexpr
|
||||
from .._c_expression import signature_rw, signature_kind
|
||||
from . import composite, operations, functional
|
||||
from . import signature
|
||||
|
||||
__primitive__ = [
|
||||
"prim_attr_register", "Primitive", "PrimitiveWithInfer",
|
||||
"signature_rw", "signature_kind"
|
||||
"prim_attr_register", "Primitive", "PrimitiveWithInfer", "signature"
|
||||
]
|
||||
|
||||
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
|
||||
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType",
|
||||
"constexpr"]
|
||||
__all__.extend(__primitive__)
|
||||
__all__.extend(composite.__all__)
|
||||
__all__.extend(operations.__all__)
|
||||
__all__.extend(functional.__all__)
|
||||
|
|
|
@ -25,9 +25,8 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult
|
|||
from ...common import dtype as mstype
|
||||
from ...common.api import ms_function, _pynative_exec, _wrap_func
|
||||
from .. import functional as F
|
||||
from ...common.parameter import Parameter
|
||||
from ...common.tensor import Tensor
|
||||
|
||||
from .. import signature as sig
|
||||
|
||||
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
|
||||
|
||||
|
@ -348,6 +347,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|||
|
||||
Args:
|
||||
name (str): Operator name.
|
||||
read_value (bool): If the registered function not need to set value on Parameter,
|
||||
and all inputs will pass by value. Set `read_value` to True. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: Cannot find matching fn for the given args.
|
||||
|
@ -358,16 +359,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|||
>>> add = MultitypeFuncGraph('add')
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
def __init__(self, name, read_value=False):
|
||||
MultitypeFuncGraph_.__init__(self, name)
|
||||
self.entries = list()
|
||||
if read_value:
|
||||
self.set_signatures((
|
||||
sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))
|
||||
|
||||
def __call__(self, *args):
|
||||
def unwrap(arg):
|
||||
if isinstance(arg, Parameter):
|
||||
return arg.data
|
||||
return arg
|
||||
types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args))
|
||||
types = tuple(map(mstype.get_py_obj_dtype, args))
|
||||
for sigs, fn in self.entries:
|
||||
if len(sigs) != len(types):
|
||||
continue
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
add = base.MultitypeFuncGraph('add')
|
||||
add = base.MultitypeFuncGraph('add', True)
|
||||
"""`add` is a metafuncgraph object which will add two objects according to input type using ".register" decorator."""
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
div = base.MultitypeFuncGraph("div")
|
||||
div = base.MultitypeFuncGraph("div", True)
|
||||
"""
|
||||
div is a metafuncgraph object which will div two objects according to input type
|
||||
using ".register" decorator
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
equal = base.MultitypeFuncGraph("equal")
|
||||
equal = base.MultitypeFuncGraph("equal", True)
|
||||
"""
|
||||
equal is a metafuncgraph object which will determine if two objects are equal according to input type
|
||||
using ".register" decorator
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
floordiv = base.MultitypeFuncGraph("floordiv")
|
||||
floordiv = base.MultitypeFuncGraph("floordiv", True)
|
||||
"""
|
||||
`floordiv` is a metafuncgraph object which will compute the floordiv of two objects
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -19,7 +19,7 @@ from .. import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
getitem = base.MultitypeFuncGraph('getitem')
|
||||
getitem = base.MultitypeFuncGraph('getitem', True)
|
||||
"""
|
||||
getitem is a metafuncgraph object which will get item from an object according to input type
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
|
||||
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
|
||||
# using ".register" decorator
|
||||
greater_equal = base.MultitypeFuncGraph("greater_equal")
|
||||
greater_equal = base.MultitypeFuncGraph("greater_equal", True)
|
||||
|
||||
|
||||
@greater_equal.register("Number", "Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
|
||||
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
|
||||
# using ".register" decorator
|
||||
greater = base.MultitypeFuncGraph("greater")
|
||||
greater = base.MultitypeFuncGraph("greater", True)
|
||||
|
||||
|
||||
@greater.register("Number", "Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from . import _constexpr_utils as const_utils
|
|||
from ... import functional as F
|
||||
from ...composite import base
|
||||
|
||||
in_ = base.MultitypeFuncGraph("in")
|
||||
in_ = base.MultitypeFuncGraph("in", True)
|
||||
"""
|
||||
in_ is a metafuncgraph object which will determine if a in b
|
||||
using ".register" decorator
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
|
||||
# less_equal is a metagraph object which will determine if two objects are less_equal according to input type
|
||||
# using ".register" decorator
|
||||
less_equal = base.MultitypeFuncGraph("less_equal")
|
||||
less_equal = base.MultitypeFuncGraph("less_equal", True)
|
||||
|
||||
|
||||
@less_equal.register("Number", "Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
|
||||
# less is a metafuncgraph object which will determine if two objects are less according to input type
|
||||
# using ".register" decorator
|
||||
less = base.MultitypeFuncGraph("less")
|
||||
less = base.MultitypeFuncGraph("less", True)
|
||||
|
||||
|
||||
@less.register("Number", "Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
|
||||
# logical_not is a metagraph object which will generate function according to input type
|
||||
# using ".register" decorator
|
||||
logical_not = base.MultitypeFuncGraph("logical_not")
|
||||
logical_not = base.MultitypeFuncGraph("logical_not", True)
|
||||
|
||||
|
||||
@logical_not.register("Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
|
||||
# logical_and is a metagraph object which will generate function according to input type
|
||||
# using ".register" decorator
|
||||
logical_and = base.MultitypeFuncGraph("logical_and")
|
||||
logical_and = base.MultitypeFuncGraph("logical_and", True)
|
||||
|
||||
|
||||
@logical_and.register("Number", "Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
|
||||
# logical_or is a metagraph object which will generate function according to input type
|
||||
# using ".register" decorator
|
||||
logical_or = base.MultitypeFuncGraph("logical_or")
|
||||
logical_or = base.MultitypeFuncGraph("logical_or", True)
|
||||
|
||||
|
||||
@logical_or.register("Number", "Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
mod = base.MultitypeFuncGraph("mod")
|
||||
mod = base.MultitypeFuncGraph("mod", True)
|
||||
"""
|
||||
`mod` is a metafuncgraph object which will compute the mod of two objects
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
mul = base.MultitypeFuncGraph("mul")
|
||||
mul = base.MultitypeFuncGraph("mul", True)
|
||||
"""
|
||||
`mul` is a metafuncgraph object which will multiply two objects according to input type
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
negative = base.MultitypeFuncGraph("negative")
|
||||
negative = base.MultitypeFuncGraph("negative", True)
|
||||
"""
|
||||
`negative` is a metafuncgraph object which will give the negative of an object according to its input type
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
not_equal = base.MultitypeFuncGraph("not_equal")
|
||||
not_equal = base.MultitypeFuncGraph("not_equal", True)
|
||||
"""
|
||||
not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type
|
||||
using ".register" decorator
|
||||
|
|
|
@ -22,7 +22,7 @@ from ... import functional as F
|
|||
from ... import operations as P
|
||||
|
||||
|
||||
ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf')
|
||||
ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf', True)
|
||||
"""
|
||||
`ones_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
pow_ = base.MultitypeFuncGraph("pow")
|
||||
pow_ = base.MultitypeFuncGraph("pow", True)
|
||||
"""
|
||||
`pow` is a metafuncgraph object which will compute the pow of two objects
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
sub = base.MultitypeFuncGraph("sub")
|
||||
sub = base.MultitypeFuncGraph("sub", True)
|
||||
"""
|
||||
`sub` is a metafuncgraph object which will compute the subtraction of two objects
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.ops.composite import base
|
|||
|
||||
# uadd is a metagraph object which will return operation result regarding input
|
||||
# using ".register" decorator
|
||||
uadd = base.MultitypeFuncGraph("uadd")
|
||||
uadd = base.MultitypeFuncGraph("uadd", True)
|
||||
|
||||
@uadd.register("Tensor")
|
||||
@uadd.register("Number")
|
||||
|
|
|
@ -19,7 +19,7 @@ from ...composite import base
|
|||
from ... import functional as F
|
||||
|
||||
|
||||
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf')
|
||||
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True)
|
||||
"""
|
||||
`zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type
|
||||
using ".register" decorator.
|
||||
|
|
|
@ -21,7 +21,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
|
|||
from .primitive import Primitive
|
||||
from . import operations as P
|
||||
from .operations import _grad_ops
|
||||
from .._extends import builtin_operations as BP
|
||||
|
||||
typeof = Primitive('typeof')
|
||||
hastype = Primitive('hastype')
|
||||
|
@ -182,5 +181,6 @@ tensor_operator_registry.register('__gt__', tensor_gt)
|
|||
tensor_operator_registry.register('__ge__', tensor_ge)
|
||||
tensor_operator_registry.register('shape', shape)
|
||||
# support GE backend for no compare operators
|
||||
tensor_operator_registry.register('vm_compare', BP.vm_compare)
|
||||
tensor_operator_registry.register('cast', cast)
|
||||
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
|
||||
"""Operators for gradients."""
|
||||
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from .. import signature as sig
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
from .._utils import get_concat_offset
|
||||
|
@ -1500,7 +1499,7 @@ class RefToEmbed(Primitive):
|
|||
>>> return key, self.weight
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_REF, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
sig.make_sig('variable', sig.sig_rw.RW_REF),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -28,10 +28,7 @@ import numpy as np
|
|||
from .._utils import get_concat_offset
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import typing
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
|
@ -44,9 +41,9 @@ class _ScatterOp(PrimitiveWithInfer):
|
|||
Define Scatter operators
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('updates', dtype=sig.sig_dtype.T)
|
||||
)
|
||||
|
||||
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
|
||||
|
@ -1396,7 +1393,7 @@ class Tile(PrimitiveWithInfer):
|
|||
validator.check_value_type("shape", multiples_v, [tuple], self.name)
|
||||
for i, multiple in enumerate(multiples_v):
|
||||
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
|
||||
validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name)
|
||||
validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name)
|
||||
len_sub = len(multiples_v) - len(x_shp)
|
||||
multiples_w = None
|
||||
if len_sub == 0:
|
||||
|
|
|
@ -18,9 +18,7 @@
|
|||
import copy
|
||||
import numpy as np
|
||||
from ... import context
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
|
@ -68,7 +66,7 @@ class _BinaryOp(PrimitiveWithInfer):
|
|||
Define binary operators.
|
||||
"""
|
||||
|
||||
__mindspore_signature__ = (sig_dtype.T, sig_dtype.T)
|
||||
__mindspore_signature__ = (sig.sig_dtype.T, sig.sig_dtype.T)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
|
@ -186,8 +184,8 @@ class AssignAdd(PrimitiveWithInfer):
|
|||
>>> net(value)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('value', dtype=sig.sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -237,8 +235,8 @@ class AssignSub(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('value', dtype=sig.sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -264,8 +262,8 @@ class _Reduce(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('axis', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, ()),
|
||||
sig.make_sig('input_x'),
|
||||
sig.make_sig('axis', default=())
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -22,9 +22,7 @@ from functools import reduce
|
|||
import numpy as np
|
||||
|
||||
from ... import context
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
|
@ -679,11 +677,11 @@ class FusedBatchNormEx(PrimitiveWithInfer):
|
|||
>>> output = op(input_x, scale, bias, mean, variance)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('scale', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('bias', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('mean', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('variance', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -1722,13 +1720,11 @@ class ApplyMomentum(PrimitiveWithInfer):
|
|||
Please refer to the usage in nn.ApplyMomentum.
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T1),
|
||||
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2)
|
||||
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('gradient', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('momentum', dtype=sig.sig_dtype.T2),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -3146,23 +3142,17 @@ class FusedSparseAdam(PrimitiveWithInfer):
|
|||
>>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta1_power', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta2_power', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta1', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta2', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('epsilon', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -3285,23 +3275,17 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
|
|||
>>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta1_power', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta2_power', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta1', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta2', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('epsilon', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -3394,11 +3378,11 @@ class FusedSparseFtrl(PrimitiveWithInfer):
|
|||
>>> output = net(grad, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -3492,13 +3476,13 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
|
|||
>>> output = net(grad, indices)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('l1', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('l2', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -3754,16 +3738,15 @@ class ApplyAdaMax(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T1),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3),
|
||||
('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4),
|
||||
('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T5),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta1_power', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('beta1', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('beta2', dtype=sig.sig_dtype.T4),
|
||||
sig.make_sig('epsilon', dtype=sig.sig_dtype.T5),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -3873,14 +3856,13 @@ class ApplyAdadelta(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum_update', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('rho', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum_update', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('rho', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('epsilon', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -3971,10 +3953,10 @@ class ApplyAdagrad(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4054,10 +4036,10 @@ class ApplyAdagradV2(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4137,10 +4119,10 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4224,10 +4206,10 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4313,12 +4295,12 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('l1', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('l2', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4418,13 +4400,13 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('l1', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('l2', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T4),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4508,14 +4490,13 @@ class ApplyAddSign(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T3),
|
||||
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('alpha', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('sign_decay', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('beta', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4618,14 +4599,13 @@ class ApplyPowerSign(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('lr', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('logbase', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('sign_decay', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('beta', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4704,9 +4684,9 @@ class ApplyGradientDescent(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('alpha', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('delta', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -4777,11 +4757,11 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3),
|
||||
('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('alpha', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('l1', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('l2', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('delta', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -5032,11 +5012,11 @@ class SparseApplyFtrl(PrimitiveWithCheck):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -5126,11 +5106,11 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
||||
sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('grad', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -15,9 +15,7 @@
|
|||
|
||||
"""Other operators."""
|
||||
import functools
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
|
@ -53,8 +51,8 @@ class Assign(Primitive):
|
|||
>>> net(x)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('value', dtype=sig.sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -14,17 +14,13 @@
|
|||
# ============================================================================
|
||||
|
||||
"""primitive"""
|
||||
|
||||
import inspect
|
||||
import copy
|
||||
from mindspore.common.api import _wrap_func
|
||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from mindspore import context
|
||||
from .._c_expression import Primitive_, real_run_op, prim_type
|
||||
from .._c_expression import signature_rw as sig_rw
|
||||
from .._c_expression import signature_kind as sig_kind
|
||||
from .._c_expression import signature_dtype as sig_dtype
|
||||
|
||||
from . import signature as sig
|
||||
|
||||
class Primitive(Primitive_):
|
||||
"""
|
||||
|
@ -54,24 +50,21 @@ class Primitive(Primitive_):
|
|||
self._update_parameter = False
|
||||
Primitive_.__init__(self, name, self)
|
||||
if hasattr(self.__class__, '__mindspore_signature__'):
|
||||
sig = self._fill_signature(self.__class__.__mindspore_signature__)
|
||||
self.set_signatures(sig)
|
||||
out = self._fill_signature(self.__class__.__mindspore_signature__)
|
||||
self.set_signatures(out)
|
||||
|
||||
def _fill_signature(self, signatures):
|
||||
"""fills signature."""
|
||||
signatures_new = []
|
||||
for signature in signatures:
|
||||
if isinstance(signature, sig_dtype):
|
||||
signatures_new.append(("argument", sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD,
|
||||
sig_kind.KIND_EMPTY_DEFAULT_VALUE, signature))
|
||||
if isinstance(signature, sig.Signature):
|
||||
signatures_new.append(signature)
|
||||
elif isinstance(signature, sig.sig_dtype):
|
||||
signatures_new.append(sig.make_sig(dtype=signature))
|
||||
else:
|
||||
if len(signature) < 3:
|
||||
raise ValueError(f"[Internal Error]Signature for one parameter len must > 3, but {signature}")
|
||||
if len(signature) == 3:
|
||||
signature += (sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T_EMPTY_DEFAULT_VALUE)
|
||||
if len(signature) == 4:
|
||||
signature += (sig_dtype.T_EMPTY_DEFAULT_VALUE,)
|
||||
signatures_new.append(signature)
|
||||
signatures_new.append(sig.make_sig(*signature))
|
||||
return tuple(signatures_new)
|
||||
|
||||
def _clone(self):
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""signature"""
|
||||
|
||||
from .._c_expression import signature_rw as sig_rw
|
||||
from .._c_expression import signature_kind as sig_kind
|
||||
from .._c_expression import signature_dtype as sig_dtype
|
||||
from .._c_expression import Signature
|
||||
|
||||
|
||||
def make_sig(name="var", rw=sig_rw.RW_READ,
|
||||
kind=sig_kind.KIND_POSITIONAL_KEYWORD,
|
||||
default=sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
dtype=sig_dtype.T_EMPTY_DEFAULT_VALUE):
|
||||
"""
|
||||
Make signature for one argument.
|
||||
|
||||
See `ApplyMomentum` in `mindspore.ops.operation.nn_ops` as a example.
|
||||
|
||||
Args:
|
||||
name (bool): Argument name. Default: "var".
|
||||
rw (:class:`mindspore.ops.signature.sig_rw`): Tag the argument attribute for write and read. Choose in
|
||||
[sig_rw.RW_READ, sig_rw.RW_WRITE, sig_rw.RW_REF]`, tag if the argument will update the input.
|
||||
`sig_rw.RW_READ` for read only argument and `sig_rw.RW_WRITE` for write only argument. `sig_rw.RW_READ`
|
||||
for the argument both need read and write. Default: sig_rw.RW_READ.
|
||||
kind (:class:`mindspore.ops.signature.kind`): Choose in `[signature_kind.KIND_POSITIONAL_KEYWORD,
|
||||
signature_kind.KIND_VAR_POSITIONAL, signature_kind.KIND_KEYWORD_ONLY, signature_kind.KIND_VAR_KEYWARD]`.
|
||||
The meaning is the same as python argument kind, please refer to the python document.
|
||||
Default: sig_kind.KIND_POSITIONAL_KEYWORD.
|
||||
default (Any): The default value of argument or `sig_kind.KIND_EMPTY_DEFAULT_VALUE` for no default value.
|
||||
Default: sig_kind.KIND_EMPTY_DEFAULT_VALUE.
|
||||
dtype (:class:`mindspore.ops.signature.sig_dtype`): Choose in `signature_dtype.T` or
|
||||
`signature_dtype.T1` to `signature_dtype.T9` or `sig_dtype.T_EMPTY_DEFAULT_VALUE` for no constraints.
|
||||
If the signature of one argument is the same as another argument, we will perform auto type convert
|
||||
between them. If any `sig_rw.RW_WRITE` argument, we will try to convert the other arguments to the
|
||||
`sig_rw.RW_WRITE` argument. Default: sig_dtype.T_EMPTY_DEFAULT_VALUE.
|
||||
|
||||
Returns:
|
||||
:class:`mindspore.ops.signature.Signature`, signature for one argument.
|
||||
"""
|
||||
return Signature(name, rw, kind, default, dtype)
|
|
@ -136,13 +136,15 @@ class NetForCast(nn.Cell):
|
|||
super(NetForCast, self).__init__()
|
||||
self.concat = P.Concat()
|
||||
self.x1 = Tensor(1.0, mstype.float32)
|
||||
self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
|
||||
|
||||
def construct(self, x0):
|
||||
x = self.x1 * x0
|
||||
x = self.x1 * x0 * self.x2
|
||||
return x
|
||||
|
||||
|
||||
def test_cast():
|
||||
context.set_context(save_graphs=True)
|
||||
x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
|
||||
net = NetForCast()
|
||||
net.add_flags_recursive(fp16=True)
|
||||
|
|
|
@ -16,9 +16,7 @@
|
|||
import functools
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore._c_expression import signature_dtype as sig_dtype
|
||||
from mindspore._c_expression import signature_kind as sig_kind
|
||||
from mindspore._c_expression import signature_rw as sig_rw
|
||||
from mindspore.ops.signature import sig_rw, sig_dtype, make_sig
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
|
@ -126,9 +124,9 @@ class CustomOP(PrimitiveWithInfer):
|
|||
|
||||
class CustomOP2(PrimitiveWithInfer):
|
||||
__mindspore_signature__ = (
|
||||
('p1', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('p2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('p3', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
make_sig('p1', sig_rw.RW_WRITE, dtype=sig_dtype.T),
|
||||
make_sig('p2', dtype=sig_dtype.T),
|
||||
make_sig('p3', dtype=sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue