add operator diag and diag_part

This commit is contained in:
zhaozhenlong 2020-03-31 09:34:09 +08:00
parent f1b722297e
commit b12e6ff780
7 changed files with 125 additions and 15 deletions

View File

@ -178,6 +178,8 @@ const char kNameLARSUpdate[] = "LARSUpdate";
const char kNameRound[] = "Round";
const char kNamePrint[] = "Print";
const char kNameApplyFtrl[] = "ApplyFtrl";
const char kNameDiag[] = "Diag";
const char kNameDiagPart[] = "DiagPart";
// -----------------OpAdapter initialization--------------
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
@ -357,7 +359,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
{string(kNameSign), ADPT_DESC(Sign)},
{string(kNameRound), ADPT_DESC(Round)},
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)}};
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
{string(kNameDiag), ADPT_DESC(Diag)},
{string(kNameDiagPart), ADPT_DESC(DiagPart)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
#endif

View File

@ -1173,6 +1173,16 @@ INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INP
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
// Diag
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Diag) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}};
// DiagPart
INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
#ifdef ENABLE_GE
// Print
INPUT_MAP(Print) = EMPTY_INPUT_MAP;

View File

@ -435,6 +435,10 @@ DECLARE_OP_ADAPTER(Round)
DECLARE_OP_USE_OUTPUT(Round)
DECLARE_OP_ADAPTER(ApplyFtrl)
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
DECLARE_OP_ADAPTER(Diag)
DECLARE_OP_USE_OUTPUT(Diag)
DECLARE_OP_ADAPTER(DiagPart)
DECLARE_OP_USE_OUTPUT(DiagPart)
#ifdef ENABLE_GE
DECLARE_OP_ADAPTER(Print)
DECLARE_OP_USE_DYN_INPUT(Print)

View File

@ -408,3 +408,25 @@ def get_bprop_depth_to_space(self):
return (op(dout),)
return bprop
@bprop_getters.register(P.Diag)
def get_bprop_diag(self):
    """Build the backward function for the Diag primitive.

    The returned closure maps the incoming gradient through DiagPart and
    returns it as a one-element tuple (one gradient per forward input).
    """
    diag_part = P.DiagPart()

    def bprop(x, out, dout):
        dx = diag_part(dout)
        return (dx,)

    return bprop
@bprop_getters.register(P.DiagPart)
def get_bprop_diag_part(self):
    """Build the backward function for the DiagPart primitive.

    The returned closure maps the incoming gradient through Diag and
    returns it as a one-element tuple (one gradient per forward input).
    """
    diag = P.Diag()

    def bprop(x, out, dout):
        dx = diag(dout)
        return (dx,)

    return bprop

View File

@ -20,7 +20,7 @@ A collection of operators to build neural networks or computing functions.
"""
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat,
Diag, DType, ExpandDims, Eye,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
@ -208,6 +208,7 @@ __all__ = [
"Cos",
"ACos",
"Diag",
"DiagPart",
'Eye',
'Assign',
'AssignAdd',

View File

@ -1615,37 +1615,96 @@ class StridedSlice(PrimitiveWithInfer):
class Diag(PrimitiveWithInfer):
r"""
Extract or construct a diagonal array.
Construct a diagonal tensor with the given diagonal values.
If input is a 2-D tensor, returns the diagonal of the input with the given offset. If
input is a 1-D tensor, returns the array of diagonals. If you use this function
to extract the diagonal and want to write to the result array, see the more
detailed documentation for "numpy.diagonal", whether you return a copy or a
view depends on the version of numpy you are using.
Assume `input_x` has dimensions :math:`[D_1,..., D_k]`, the output is a tensor of
rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where:
:math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else.
Inputs:
- **input_x** (Tensor) - 1-D tensor or 2-D tensor.
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor.
Examples:
>>> input_x = Tensor([1, 2, 3, 4])
>>> diag = P.Diag()
>>> diag(input_x)
[[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Diag."""
def infer_type(self, x):
args = {"x_dtype": x}
validator.check_subclass('input_x', x, mstype.tensor)
validator.check_type_same(args, mstype.number_type)
return x
# The input must be a tensor type; the output dtype is the input dtype unchanged.
def infer_dtype(self, x_type):
validator.check_subclass('input_x', x_type, mstype.tensor)
return x_type
# The output shape is the input shape concatenated with itself.
def infer_shape(self, x_shape):
validator.check("x rank", len(x_shape), "", 1, Rel.GE)
ret_shape = copy.deepcopy(x_shape)
ret_shape = ret_shape + ret_shape
return ret_shape
# NOTE(review): the value path builds its result with np.diag and appears to
# require a rank-1 input, while infer_shape accepts any rank >= 1 -- confirm
# whether higher-rank constant inputs are expected to reach this method.
def infer_value(self, x):
validator.check("shape_length", len(x.shape()), "length", [1, 2], Rel.IN)
if x is None:
return None
validator.check("input x rank", len(x.shape()), "", 1)
ret = np.diag(x.asnumpy())
return Tensor(ret)
class DiagPart(PrimitiveWithInfer):
    r"""

    Extract the diagonal part from given tensor.

    Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor
    of rank k with dimensions :math:`[D_1,..., D_k]` where:
    :math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.

    Inputs:
        - **input_x** (Tensor) - The input Tensor. Its rank must be non-zero and even, and
          the first half of its dimensions must equal the second half.

    Outputs:
        Tensor, with the same dtype as `input_x` and a shape equal to the first half of the
        dimensions of `input_x`.

    Raises:
        ValueError: If the rank of `input_x` is zero or odd, or if the two halves of its
            shape differ.

    Examples:
        >>> input_x = Tensor([[1, 0, 0, 0],
        >>>                   [0, 2, 0, 0],
        >>>                   [0, 0, 3, 0],
        >>>                   [0, 0, 0, 4]])
        >>> diag_part = P.DiagPart()
        >>> diag_part(input_x)
        [1, 2, 3, 4]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DiagPart."""

    def infer_dtype(self, x_type):
        """Return the input dtype unchanged after checking it is a tensor type."""
        validator.check_subclass('input_x', x_type, mstype.tensor)
        return x_type

    def infer_shape(self, x_shape):
        """Return the first half of `x_shape` after validating its layout.

        Raises:
            ValueError: If the rank is zero or odd, or if the first half of the
                dimensions does not equal the second half.
        """
        if not x_shape or len(x_shape) % 2 != 0:
            raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, "
                             f"with shapes {x_shape}")
        length = len(x_shape) // 2
        ret_shape = x_shape[0:length]
        # The documented contract is [D_1..D_k, D_1..D_k]; without this check a
        # shape such as [2, 3] would be accepted and yield a bogus output shape.
        if list(ret_shape) != list(x_shape[length:]):
            raise ValueError(f"DiagPart input shape must have identical first and second halves, "
                             f"but got shapes {x_shape}")
        return ret_shape

    def infer_value(self, x):
        """Fold a constant input into its diagonal part; return None when no value is known."""
        if x is None:
            return None
        x_np = x.asnumpy()
        # Reuse the shape validation so the value path accepts exactly the
        # inputs that infer_shape accepts (any non-zero even rank).
        ret_shape = self.infer_shape(list(x_np.shape))
        # Collapse each half of the dimensions into one axis, take the ordinary
        # matrix diagonal, then restore the k-dimensional output shape.
        size = int(np.prod(ret_shape))
        ret = np.diag(x_np.reshape(size, size)).reshape(ret_shape)
        return Tensor(ret)

View File

@ -942,6 +942,16 @@ test_case_array_ops = [
Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)))],
'desc_bprop': [[3,]]}),
('Diag', {
'block': P.Diag(),
'desc_inputs': [[4]],
'desc_bprop': [[4, 4]],
}),
('DiagPart', {
'block': P.DiagPart(),
'desc_inputs': [[4, 4]],
'desc_bprop': [[4]],
}),
]
test_case_other_ops = [