forked from mindspore-Ecosystem/mindspore
add operator diag and diag_part
This commit is contained in:
parent
f1b722297e
commit
b12e6ff780
|
@ -178,6 +178,8 @@ const char kNameLARSUpdate[] = "LARSUpdate";
|
||||||
const char kNameRound[] = "Round";
|
const char kNameRound[] = "Round";
|
||||||
const char kNamePrint[] = "Print";
|
const char kNamePrint[] = "Print";
|
||||||
const char kNameApplyFtrl[] = "ApplyFtrl";
|
const char kNameApplyFtrl[] = "ApplyFtrl";
|
||||||
|
const char kNameDiag[] = "Diag";
|
||||||
|
const char kNameDiagPart[] = "DiagPart";
|
||||||
|
|
||||||
// -----------------OpAdapter initialization--------------
|
// -----------------OpAdapter initialization--------------
|
||||||
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
||||||
|
@ -357,7 +359,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
||||||
{string(kNameSign), ADPT_DESC(Sign)},
|
{string(kNameSign), ADPT_DESC(Sign)},
|
||||||
{string(kNameRound), ADPT_DESC(Round)},
|
{string(kNameRound), ADPT_DESC(Round)},
|
||||||
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)}};
|
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
|
||||||
|
{string(kNameDiag), ADPT_DESC(Diag)},
|
||||||
|
{string(kNameDiagPart), ADPT_DESC(DiagPart)}};
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -1173,6 +1173,16 @@ INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INP
|
||||||
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
|
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
|
||||||
|
|
||||||
|
// Diag
|
||||||
|
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
||||||
|
ATTR_MAP(Diag) = EMPTY_ATTR_MAP;
|
||||||
|
OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
|
// DiagPart
|
||||||
|
INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
|
||||||
|
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
|
||||||
|
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
// Print
|
// Print
|
||||||
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
||||||
|
|
|
@ -435,6 +435,10 @@ DECLARE_OP_ADAPTER(Round)
|
||||||
DECLARE_OP_USE_OUTPUT(Round)
|
DECLARE_OP_USE_OUTPUT(Round)
|
||||||
DECLARE_OP_ADAPTER(ApplyFtrl)
|
DECLARE_OP_ADAPTER(ApplyFtrl)
|
||||||
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
|
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
|
||||||
|
DECLARE_OP_ADAPTER(Diag)
|
||||||
|
DECLARE_OP_USE_OUTPUT(Diag)
|
||||||
|
DECLARE_OP_ADAPTER(DiagPart)
|
||||||
|
DECLARE_OP_USE_OUTPUT(DiagPart)
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
DECLARE_OP_ADAPTER(Print)
|
DECLARE_OP_ADAPTER(Print)
|
||||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||||
|
|
|
@ -408,3 +408,25 @@ def get_bprop_depth_to_space(self):
|
||||||
return (op(dout),)
|
return (op(dout),)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.Diag)
|
||||||
|
def get_bprop_diag(self):
|
||||||
|
"""Generate bprop for Diag"""
|
||||||
|
op = P.DiagPart()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
return (op(dout),)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.DiagPart)
|
||||||
|
def get_bprop_diag_part(self):
|
||||||
|
"""Generate bprop for DiagPart"""
|
||||||
|
op = P.Diag()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
return (op(dout),)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
|
@ -20,7 +20,7 @@ A collection of operators to build nerual networks or computing functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat,
|
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat,
|
||||||
Diag, DType, ExpandDims, Eye,
|
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||||
Fill, GatherNd, GatherV2, InvertPermutation,
|
Fill, GatherNd, GatherV2, InvertPermutation,
|
||||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||||
|
@ -208,6 +208,7 @@ __all__ = [
|
||||||
"Cos",
|
"Cos",
|
||||||
"ACos",
|
"ACos",
|
||||||
"Diag",
|
"Diag",
|
||||||
|
"DiagPart",
|
||||||
'Eye',
|
'Eye',
|
||||||
'Assign',
|
'Assign',
|
||||||
'AssignAdd',
|
'AssignAdd',
|
||||||
|
|
|
@ -1615,37 +1615,96 @@ class StridedSlice(PrimitiveWithInfer):
|
||||||
class Diag(PrimitiveWithInfer):
|
class Diag(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
|
|
||||||
Extract or construct a diagonal array.
|
Construct a diagonal tensor with a given diagonal values.
|
||||||
|
|
||||||
If input is a 2-D tensor, returns the diagonal of the input with the given offset. If
|
Assume `input_x` has dimensions :math:`[D_1,... D_k]`, the output is a tensor of
|
||||||
input is a 1-D tensor, returns the array of diagonals. If you use this function
|
rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where:
|
||||||
to extract the diagonal and want to write to the result array, see the more
|
:math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else.
|
||||||
detailed documentation for "numpy.diagonal", whether you return a copy or a
|
|
||||||
view depends on the version of numpy you are using.
|
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_x** (Tensor) - 1-D tensor or 2-D tensor.
|
- **input_x** (Tensor) - The input tensor.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor.
|
Tensor.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> input_x = Tensor([1, 2, 3, 4])
|
||||||
>>> diag = P.Diag()
|
>>> diag = P.Diag()
|
||||||
>>> diag(x)
|
>>> diag(x)
|
||||||
|
[[1, 0, 0, 0],
|
||||||
|
[0, 2, 0, 0],
|
||||||
|
[0, 0, 3, 0],
|
||||||
|
[0, 0, 0, 4]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init Diag"""
|
"""init Diag"""
|
||||||
|
|
||||||
def infer_type(self, x):
|
def infer_dtype(self, x_type):
|
||||||
args = {"x_dtype": x}
|
validator.check_subclass('input_x', x_type, mstype.tensor)
|
||||||
validator.check_subclass('input_x', x, mstype.tensor)
|
return x_type
|
||||||
validator.check_type_same(args, mstype.number_type)
|
|
||||||
return x
|
def infer_shape(self, x_shape):
|
||||||
|
validator.check("x rank", len(x_shape), "", 1, Rel.GE)
|
||||||
|
ret_shape = copy.deepcopy(x_shape)
|
||||||
|
ret_shape = ret_shape + ret_shape
|
||||||
|
return ret_shape
|
||||||
|
|
||||||
def infer_value(self, x):
|
def infer_value(self, x):
|
||||||
validator.check("shape_length", len(x.shape()), "length", [1, 2], Rel.IN)
|
if x is None:
|
||||||
|
return None
|
||||||
|
validator.check("input x rank", len(x.shape()), "", 1)
|
||||||
|
ret = np.diag(x.asnumpy())
|
||||||
|
return Tensor(ret)
|
||||||
|
|
||||||
|
|
||||||
|
class DiagPart(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
|
||||||
|
Extract the diagonal part from given tensor.
|
||||||
|
|
||||||
|
Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor
|
||||||
|
of rank k with dimensions :math:`[D_1,..., D_k]` where:
|
||||||
|
:math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input Tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
>>> input_x = Tensor([[1, 0, 0, 0],
|
||||||
|
>>> [0, 2, 0, 0],
|
||||||
|
>>> [0, 0, 3, 0],
|
||||||
|
>>> [0, 0, 0, 4]])
|
||||||
|
>>> diag_part = P.DiagPart()
|
||||||
|
>>> diag_part(x)
|
||||||
|
[1, 2, 3, 4]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init DiagPart"""
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type):
|
||||||
|
validator.check_subclass('input_x', x_type, mstype.tensor)
|
||||||
|
return x_type
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
if len(x_shape)%2 != 0 or \
|
||||||
|
not x_shape:
|
||||||
|
raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, "
|
||||||
|
f"with shapes {x_shape}")
|
||||||
|
length = len(x_shape) // 2
|
||||||
|
ret_shape = x_shape[0:length]
|
||||||
|
return ret_shape
|
||||||
|
|
||||||
|
def infer_value(self, x):
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
|
validator.check("x rank", len(x.shape()), "", 2)
|
||||||
ret = np.diag(x.asnumpy())
|
ret = np.diag(x.asnumpy())
|
||||||
return Tensor(ret)
|
return Tensor(ret)
|
||||||
|
|
||||||
|
|
|
@ -942,6 +942,16 @@ test_case_array_ops = [
|
||||||
Tensor(np.array([1], np.float32)),
|
Tensor(np.array([1], np.float32)),
|
||||||
Tensor(np.array([1], np.float32)))],
|
Tensor(np.array([1], np.float32)))],
|
||||||
'desc_bprop': [[3,]]}),
|
'desc_bprop': [[3,]]}),
|
||||||
|
('Diag', {
|
||||||
|
'block': P.Diag(),
|
||||||
|
'desc_inputs': [[4]],
|
||||||
|
'desc_bprop': [[4, 4]],
|
||||||
|
}),
|
||||||
|
('DiagPart', {
|
||||||
|
'block': P.DiagPart(),
|
||||||
|
'desc_inputs': [[4, 4]],
|
||||||
|
'desc_bprop': [[4]],
|
||||||
|
}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_other_ops = [
|
test_case_other_ops = [
|
||||||
|
|
Loading…
Reference in New Issue