forked from mindspore-Ecosystem/mindspore
add operator diag and diag_part
This commit is contained in:
parent
f1b722297e
commit
b12e6ff780
|
@ -178,6 +178,8 @@ const char kNameLARSUpdate[] = "LARSUpdate";
|
|||
const char kNameRound[] = "Round";
|
||||
const char kNamePrint[] = "Print";
|
||||
const char kNameApplyFtrl[] = "ApplyFtrl";
|
||||
const char kNameDiag[] = "Diag";
|
||||
const char kNameDiagPart[] = "DiagPart";
|
||||
|
||||
// -----------------OpAdapter initialization--------------
|
||||
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
||||
|
@ -357,7 +359,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
||||
{string(kNameSign), ADPT_DESC(Sign)},
|
||||
{string(kNameRound), ADPT_DESC(Round)},
|
||||
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)}};
|
||||
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
|
||||
{string(kNameDiag), ADPT_DESC(Diag)},
|
||||
{string(kNameDiagPart), ADPT_DESC(DiagPart)}};
|
||||
#ifdef ENABLE_GE
|
||||
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
||||
#endif
|
||||
|
|
|
@ -1173,6 +1173,16 @@ INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INP
|
|||
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
|
||||
|
||||
// Diag
|
||||
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(Diag) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
// DiagPart
|
||||
INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
#ifdef ENABLE_GE
|
||||
// Print
|
||||
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
||||
|
|
|
@ -435,6 +435,10 @@ DECLARE_OP_ADAPTER(Round)
|
|||
DECLARE_OP_USE_OUTPUT(Round)
|
||||
DECLARE_OP_ADAPTER(ApplyFtrl)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
|
||||
DECLARE_OP_ADAPTER(Diag)
|
||||
DECLARE_OP_USE_OUTPUT(Diag)
|
||||
DECLARE_OP_ADAPTER(DiagPart)
|
||||
DECLARE_OP_USE_OUTPUT(DiagPart)
|
||||
#ifdef ENABLE_GE
|
||||
DECLARE_OP_ADAPTER(Print)
|
||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||
|
|
|
@ -408,3 +408,25 @@ def get_bprop_depth_to_space(self):
|
|||
return (op(dout),)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Diag)
|
||||
def get_bprop_diag(self):
|
||||
"""Generate bprop for Diag"""
|
||||
op = P.DiagPart()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
return (op(dout),)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.DiagPart)
|
||||
def get_bprop_diag_part(self):
|
||||
"""Generate bprop for DiagPart"""
|
||||
op = P.Diag()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
return (op(dout),)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -20,7 +20,7 @@ A collection of operators to build nerual networks or computing functions.
|
|||
"""
|
||||
|
||||
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat,
|
||||
Diag, DType, ExpandDims, Eye,
|
||||
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||
Fill, GatherNd, GatherV2, InvertPermutation,
|
||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||
|
@ -208,6 +208,7 @@ __all__ = [
|
|||
"Cos",
|
||||
"ACos",
|
||||
"Diag",
|
||||
"DiagPart",
|
||||
'Eye',
|
||||
'Assign',
|
||||
'AssignAdd',
|
||||
|
|
|
@ -1615,37 +1615,96 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
class Diag(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
||||
Extract or construct a diagonal array.
|
||||
Construct a diagonal tensor with a given diagonal values.
|
||||
|
||||
If input is a 2-D tensor, returns the diagonal of the input with the given offset. If
|
||||
input is a 1-D tensor, returns the array of diagonals. If you use this function
|
||||
to extract the diagonal and want to write to the result array, see the more
|
||||
detailed documentation for "numpy.diagonal", whether you return a copy or a
|
||||
view depends on the version of numpy you are using.
|
||||
Assume `input_x` has dimensions :math:`[D_1,... D_k]`, the output is a tensor of
|
||||
rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where:
|
||||
:math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - 1-D tensor or 2-D tensor.
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor([1, 2, 3, 4])
|
||||
>>> diag = P.Diag()
|
||||
>>> diag(x)
|
||||
[[1, 0, 0, 0],
|
||||
[0, 2, 0, 0],
|
||||
[0, 0, 3, 0],
|
||||
[0, 0, 0, 4]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init Diag"""
|
||||
|
||||
def infer_type(self, x):
|
||||
args = {"x_dtype": x}
|
||||
validator.check_subclass('input_x', x, mstype.tensor)
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
return x
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass('input_x', x_type, mstype.tensor)
|
||||
return x_type
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check("x rank", len(x_shape), "", 1, Rel.GE)
|
||||
ret_shape = copy.deepcopy(x_shape)
|
||||
ret_shape = ret_shape + ret_shape
|
||||
return ret_shape
|
||||
|
||||
def infer_value(self, x):
|
||||
validator.check("shape_length", len(x.shape()), "length", [1, 2], Rel.IN)
|
||||
if x is None:
|
||||
return None
|
||||
validator.check("input x rank", len(x.shape()), "", 1)
|
||||
ret = np.diag(x.asnumpy())
|
||||
return Tensor(ret)
|
||||
|
||||
|
||||
class DiagPart(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
||||
Extract the diagonal part from given tensor.
|
||||
|
||||
Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor
|
||||
of rank k with dimensions :math:`[D_1,..., D_k]` where:
|
||||
:math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input Tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor.
|
||||
|
||||
Examples
|
||||
>>> input_x = Tensor([[1, 0, 0, 0],
|
||||
>>> [0, 2, 0, 0],
|
||||
>>> [0, 0, 3, 0],
|
||||
>>> [0, 0, 0, 4]])
|
||||
>>> diag_part = P.DiagPart()
|
||||
>>> diag_part(x)
|
||||
[1, 2, 3, 4]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init DiagPart"""
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass('input_x', x_type, mstype.tensor)
|
||||
return x_type
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if len(x_shape)%2 != 0 or \
|
||||
not x_shape:
|
||||
raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, "
|
||||
f"with shapes {x_shape}")
|
||||
length = len(x_shape) // 2
|
||||
ret_shape = x_shape[0:length]
|
||||
return ret_shape
|
||||
|
||||
def infer_value(self, x):
|
||||
if x is None:
|
||||
return None
|
||||
validator.check("x rank", len(x.shape()), "", 2)
|
||||
ret = np.diag(x.asnumpy())
|
||||
return Tensor(ret)
|
||||
|
||||
|
|
|
@ -942,6 +942,16 @@ test_case_array_ops = [
|
|||
Tensor(np.array([1], np.float32)),
|
||||
Tensor(np.array([1], np.float32)))],
|
||||
'desc_bprop': [[3,]]}),
|
||||
('Diag', {
|
||||
'block': P.Diag(),
|
||||
'desc_inputs': [[4]],
|
||||
'desc_bprop': [[4, 4]],
|
||||
}),
|
||||
('DiagPart', {
|
||||
'block': P.DiagPart(),
|
||||
'desc_inputs': [[4, 4]],
|
||||
'desc_bprop': [[4]],
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_other_ops = [
|
||||
|
|
Loading…
Reference in New Issue