forked from mindspore-Ecosystem/mindspore
add composite op doc
This commit is contained in:
parent
f984518150
commit
8c6475fd0b
|
@ -27,6 +27,7 @@
|
|||
#include "utils/ms_context.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "ir/signature.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -57,31 +58,27 @@ void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &
|
|||
fn_cache_py_[types] = py_fn;
|
||||
}
|
||||
|
||||
void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
|
||||
void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
|
||||
TypePtrList types;
|
||||
for (auto &type_name : types_name) {
|
||||
auto type_ptr = StringToType(type_name);
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error ";
|
||||
for (size_t it = 0; it < tuple.size(); ++it) {
|
||||
py::object type_in = tuple[it];
|
||||
TypePtr type_ptr = nullptr;
|
||||
if (py::isinstance<py::str>(type_in)) {
|
||||
auto type_name = type_in.cast<std::string>();
|
||||
type_ptr = StringToType(type_name);
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << type_name << " convert from string error ";
|
||||
}
|
||||
} else if (py::isinstance<Type>(type_in)) {
|
||||
type_ptr = type_in.cast<TypePtr>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Register must be string or `mindspore.dtype.Type`";
|
||||
}
|
||||
types.push_back(type_ptr);
|
||||
}
|
||||
Register(types, py_fn);
|
||||
}
|
||||
|
||||
void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
|
||||
std::vector<std::string> types_name;
|
||||
for (size_t it = 0; it < tuple.size(); ++it) {
|
||||
py::object name_py = tuple[it];
|
||||
if (py::isinstance<py::str>(name_py)) {
|
||||
types_name.push_back(name_py.cast<std::string>());
|
||||
continue;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Register must be string";
|
||||
}
|
||||
Register(types_name, py_fn);
|
||||
}
|
||||
|
||||
// Return Exact match if exists, else return non ambiguous sub class match
|
||||
// Return py::none() if matching is ambiguous
|
||||
const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
|
||||
|
|
|
@ -44,7 +44,6 @@ class MultitypeFuncGraph : public MetaFuncGraph {
|
|||
// Register a method which specialize based on types vectors;
|
||||
virtual void Register(const TypePtrList &types, specialize_fn s_fn);
|
||||
virtual void Register(const TypePtrList &types, const py::function &py_fn);
|
||||
virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
|
||||
virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
|
||||
|
||||
FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
|
||||
|
|
|
@ -396,7 +396,9 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
|||
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
|
||||
<< ret->ToString();
|
||||
node->set_abstract(ret);
|
||||
changed = true;
|
||||
if (ret->cast<abstract::AbstractTuplePtr>()->size() > 0) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
|
|
|
@ -293,7 +293,7 @@ class RowTensor:
|
|||
The dense tensor dense represented by an RowTensor slices has
|
||||
`dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
|
||||
|
||||
RowTensor can only be used in the `Cell`'s contruct method.
|
||||
RowTensor can only be used in the `Cell`'s construct method.
|
||||
|
||||
It is not supported in pynative mode at the moment.
|
||||
|
||||
|
|
|
@ -46,7 +46,6 @@ class _CellListBase():
|
|||
by iterator or subscript , it will be interpretated as a list of cells.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(_CellListBase, self).__init__()
|
||||
self.__cell_as_list__ = True
|
||||
|
||||
@abstractmethod
|
||||
|
@ -177,7 +176,8 @@ class CellList(_CellListBase, Cell):
|
|||
(2): ReLU<> >
|
||||
"""
|
||||
def __init__(self, *args):
|
||||
super(CellList, self).__init__()
|
||||
_CellListBase.__init__(self)
|
||||
Cell.__init__(self)
|
||||
if len(args) == 1:
|
||||
self.extend(args[0])
|
||||
|
||||
|
|
|
@ -341,22 +341,42 @@ class GradOperation(GradOperation_):
|
|||
|
||||
class MultitypeFuncGraph(MultitypeFuncGraph_):
|
||||
"""
|
||||
Generate multiply graph.
|
||||
Generate overloaded functions.
|
||||
|
||||
MultitypeFuncGraph is a class used to generate graphs for function with different type as input.
|
||||
MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs.
|
||||
Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator
|
||||
for the function to be registed. And the object can be called with different type of inputs,
|
||||
and work with `HyperMap` and `Map`.
|
||||
|
||||
Args:
|
||||
name (str): Operator name.
|
||||
read_value (bool): If the registered function not need to set value on Parameter,
|
||||
and all inputs will pass by value. Set `read_value` to True. Default: False.
|
||||
and all inputs will pass by value, set `read_value` to True. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: Cannot find matching fn for the given args.
|
||||
ValueError: Cannot find matching functions for the given args.
|
||||
|
||||
Examples:
|
||||
>>> # `add` is a metagraph object which will add two objects according to
|
||||
>>> # input type using ".register" decorator.
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops import Primitive, operations as P
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>>
|
||||
>>> scala_add = Primitive('scala_add')
|
||||
>>> tensor_add = P.TensorAdd()
|
||||
>>>
|
||||
>>> add = MultitypeFuncGraph('add')
|
||||
>>> @add.register("Number", "Number")
|
||||
... def add_scala(x, y):
|
||||
... return scala_add(x, y)
|
||||
>>> @add.register("Tensor", "Tensor")
|
||||
... def add_tensor(x, y):
|
||||
... return tensor_add(x, y)
|
||||
>>> add(1, 2)
|
||||
3
|
||||
>>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
|
||||
Tensor(shape=[], dtype=Float32, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, name, read_value=False):
|
||||
|
@ -378,9 +398,25 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|||
raise ValueError("Cannot find fn match given args.")
|
||||
|
||||
def register(self, *type_names):
|
||||
"""Register a function for the given type string."""
|
||||
"""
|
||||
Register a function for the given type string.
|
||||
|
||||
Args:
|
||||
type_names (Union[str, :class:`mindspore.dtype`]): Inputs type names or types list.
|
||||
|
||||
Return:
|
||||
decorator, a decorator to register the function to run, when called under the
|
||||
types described in `type_names`.
|
||||
"""
|
||||
def deco(fn):
|
||||
types = tuple(map(mstype.typing.str_to_type, type_names))
|
||||
def convert_type(type_input):
|
||||
if isinstance(type_input, str):
|
||||
return mstype.typing.str_to_type(type_input)
|
||||
if not isinstance(type_input, mstype.Type):
|
||||
raise TypeError(f"MultitypeFuncGraph register only support str or {mstype.Type}")
|
||||
return type_input
|
||||
|
||||
types = tuple(map(convert_type, type_names))
|
||||
self.register_fn(type_names, fn)
|
||||
self.entries.append((types, fn))
|
||||
return fn
|
||||
|
@ -391,11 +427,12 @@ class HyperMap(HyperMap_):
|
|||
"""
|
||||
Hypermap will apply the set operation on input sequences.
|
||||
|
||||
Which will apply the operations of every elements of the sequence.
|
||||
Apply the operations to every elements of the sequence or nested sequence. Different
|
||||
from `Map`, the `HyperMap` supports to apply on nested structure.
|
||||
|
||||
Args:
|
||||
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
|
||||
the operations should be putted in the first input of the instance.
|
||||
the operations should be put in the first input of the instance.
|
||||
|
||||
Inputs:
|
||||
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
|
||||
|
@ -405,8 +442,28 @@ class HyperMap(HyperMap_):
|
|||
If `ops` is not `None`, the first input is the operation, and the other is inputs.
|
||||
|
||||
Outputs:
|
||||
sequence, the output will be same type and same length of sequence from input and the value of each element
|
||||
is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
|
||||
Sequence or nested sequence, the sequence of output after applying the function.
|
||||
e.g. `operation(args[0][i], args[1][i])`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
|
||||
... (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
|
||||
>>> # square all the tensor in the nested list
|
||||
>>>
|
||||
>>> square = MultitypeFuncGraph('square')
|
||||
>>> @square.register("Tensor")
|
||||
... def square_tensor(x):
|
||||
... return F.square(x)
|
||||
>>>
|
||||
>>> common_map = HyperMap()
|
||||
>>> common_map(square, nest_tensor_list)
|
||||
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
||||
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
||||
>>> square_map = HyperMap(square)
|
||||
>>> square_map(nest_tensor_list)
|
||||
((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
|
||||
(Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16))
|
||||
"""
|
||||
|
||||
def __init__(self, ops=None):
|
||||
|
@ -434,11 +491,11 @@ class Map(Map_):
|
|||
"""
|
||||
Map will apply the set operation on input sequences.
|
||||
|
||||
Which will apply the operations of every elements of the sequence.
|
||||
Apply the operations to every elements of the sequence.
|
||||
|
||||
Args:
|
||||
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
|
||||
the operations should be putted in the first input of the instance.
|
||||
the operations should be put in the first input of the instance. Default: None
|
||||
|
||||
Inputs:
|
||||
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
|
||||
|
@ -448,8 +505,24 @@ class Map(Map_):
|
|||
If `ops` is not `None`, the first input is the operation, and the other is inputs.
|
||||
|
||||
Outputs:
|
||||
sequence, the output will be same type and same length of sequence from input and the value of each element
|
||||
is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
|
||||
Sequence, the sequence of output after applying the function. e.g. `operation(args[0][i], args[1][i])`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32))
|
||||
>>> # square all the tensor in the list
|
||||
>>>
|
||||
>>> square = MultitypeFuncGraph('square')
|
||||
>>> @square.register("Tensor")
|
||||
>>> def square_tensor(x):
|
||||
... return F.square(x)
|
||||
>>>
|
||||
>>> common_map = Map()
|
||||
>>> common_map(square, tensor_list)
|
||||
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
||||
>>> square_map = Map(square)
|
||||
>>> square_map(tensor_list)
|
||||
(Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
|
||||
"""
|
||||
|
||||
def __init__(self, ops=None):
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import dtype as mstype
|
||||
from ...ut_filter import non_graph_engine
|
||||
|
||||
tensor_add = P.TensorAdd()
|
||||
|
@ -62,3 +63,27 @@ def test_multitype_tuple():
|
|||
|
||||
def test_multitype_scalar():
|
||||
mainf(1, 2)
|
||||
|
||||
|
||||
add2 = C.MultitypeFuncGraph('add2')
|
||||
@add2.register(mstype.number, mstype.number)
|
||||
def add_scala2(x, y):
|
||||
return scala_add(x, y)
|
||||
|
||||
|
||||
@add2.register(mstype.tensor, mstype.tensor)
|
||||
def add_tensor2(x, y):
|
||||
return tensor_add(x, y)
|
||||
|
||||
|
||||
@ms_function
|
||||
def mainf2(x, y):
|
||||
return add2(x, y)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_multitype_tensor_by_type():
|
||||
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||
out = mainf2(tensor1, tensor2)
|
||||
print(out)
|
||||
|
|
Loading…
Reference in New Issue