Add pack and unpack

This commit is contained in:
liuxiao 2020-04-06 10:22:47 +08:00
parent dd9a5a385a
commit 47d903ff57
6 changed files with 229 additions and 3 deletions

View File

@ -135,6 +135,7 @@ extern const PrimitivePtr kPrimGatherV2;
extern const PrimitivePtr kPrimSize; extern const PrimitivePtr kPrimSize;
extern const PrimitivePtr kPrimArgMax; extern const PrimitivePtr kPrimArgMax;
extern const PrimitivePtr kPrimPack; extern const PrimitivePtr kPrimPack;
extern const PrimitivePtr kPrimUnpack;
extern const PrimitivePtr kPrimUnsortedSegmentSum; extern const PrimitivePtr kPrimUnsortedSegmentSum;
extern const PrimitivePtr kPrimConcatOffset; extern const PrimitivePtr kPrimConcatOffset;
extern const PrimitivePtr kPrimReshape; extern const PrimitivePtr kPrimReshape;

View File

@ -148,7 +148,8 @@ const char kNameSlice[] = "Slice";
const char kNameAddN[] = "AddN"; const char kNameAddN[] = "AddN";
const char kNameLess[] = "Less"; const char kNameLess[] = "Less";
const char kNameGreater[] = "Greater"; const char kNameGreater[] = "Greater";
const char kNamePack[] = "Stack"; const char kNameStack[] = "Stack";
const char kNameUnstack[] = "Unstack";
const char kNameMerge[] = "Merge"; const char kNameMerge[] = "Merge";
const char kNameGeSwitch[] = "GeSwitch"; const char kNameGeSwitch[] = "GeSwitch";
@ -199,7 +200,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameMaxPool), ADPT_DESC(MaxPool)}, {string(kNameMaxPool), ADPT_DESC(MaxPool)},
{string(kNameAvgPool), ADPT_DESC(AvgPool)}, {string(kNameAvgPool), ADPT_DESC(AvgPool)},
{string(kNameTopK), ADPT_DESC(TopKV2)}, {string(kNameTopK), ADPT_DESC(TopKV2)},
{string(kNamePack), ADPT_DESC(Pack)}, {string(kNameStack), ADPT_DESC(Pack)},
{string(kNameUnstack), ADPT_DESC(Unpack)},
{string(kNameSplitD), ADPT_DESC(SplitD)}, {string(kNameSplitD), ADPT_DESC(SplitD)},
{string(kNameAllReduce), ADPT_DESC(HcomAllReduce)}, {string(kNameAllReduce), ADPT_DESC(HcomAllReduce)},
{string(kNameBroadcast), ADPT_DESC(HcomBroadcast)}, {string(kNameBroadcast), ADPT_DESC(HcomBroadcast)},

View File

@ -266,6 +266,30 @@ def get_bprop_gather_v2(self):
return bprop return bprop
@bprop_getters.register(P.Stack)
def get_bprop_stack(self):
    """Build the backward function for Stack.

    The returned closure applies Unstack along the axis read from this
    primitive to the incoming gradient and returns the result wrapped in a
    one-element tuple.
    """
    stack_axis = self.axis

    def bprop(x, out, dout):
        unstack = P.Unstack(stack_axis)
        dx = unstack(dout)
        return (dx,)

    return bprop
@bprop_getters.register(P.Unstack)
def get_bprop_unstack(self):
    """Build the backward function for Unstack.

    The returned closure applies Stack along the axis read from this
    primitive to the incoming gradient and returns the result wrapped in a
    one-element tuple.
    """
    unstack_axis = self.axis

    def bprop(x, out, dout):
        stack = P.Stack(unstack_axis)
        dx = stack(dout)
        return (dx,)

    return bprop
@bprop_getters.register(P.StridedSlice) @bprop_getters.register(P.StridedSlice)
def get_bprop_strided_slice(self): def get_bprop_strided_slice(self):
"""Generate bprop for StridedSlice""" """Generate bprop for StridedSlice"""

View File

@ -19,7 +19,7 @@ Primitive operator classes.
A collection of operators to build neural networks or computing functions. A collection of operators to build neural networks or computing functions.
""" """
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Stack, Unstack,
Diag, DiagPart, DType, ExpandDims, Eye, Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation, Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
@ -112,6 +112,8 @@ __all__ = [
'OneHot', 'OneHot',
'GatherV2', 'GatherV2',
'Concat', 'Concat',
'Stack',
'Unstack',
'Tile', 'Tile',
'BiasAdd', 'BiasAdd',
'Gelu', 'Gelu',

View File

@ -1350,6 +1350,150 @@ class Concat(PrimitiveWithInfer):
return out return out
def _get_stack_shape(x_shape, x_type, axis):
    """Validate the inputs of Stack and compute its output shape.

    Args:
        x_shape: Sequence holding one shape per tensor to stack; it is
            checked with `validator.check_type` against `tuple` and must be
            non-empty.
        x_type: Sequence of the tensors' types, indexed in parallel with
            `x_shape`; every entry is compared against `x_type[0]`.
        axis (int): Position of the new dimension. It is checked against the
            inclusive range [-(R + 1), R], where R is the rank of the first
            shape; a negative value is shifted by R + 1.

    Returns:
        list, a copy of the first shape with `len(x_shape)` inserted at
        the normalised `axis`.

    Raises:
        ValueError: If any shape differs element-wise from the first shape.
            Other violations are reported by the `validator` helpers.
    """
    validator.check_type("shape", x_shape, [tuple])
    validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT)
    validator.check_subclass("shape0", x_type[0], mstype.tensor)
    validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT)
    rank_base = len(x_shape[0])
    N = len(x_shape)
    # Work on a copy: inserting into x_shape[0] itself would silently alter
    # the caller's shape data (and would fail if that shape were a tuple).
    out_shape = list(x_shape[0])
    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
    if axis < 0:
        axis = axis + rank_base + 1
    for i in range(1, N):
        v = x_shape[i]
        validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base)
        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
        for j in range(rank_base):
            if v[j] != x_shape[0][j]:
                raise ValueError("Stack evaluator element %d shape in input can not stack with first element" % i)
    out_shape.insert(axis, N)
    return out_shape
class Stack(PrimitiveWithInfer):
    r"""
    Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

    The tensors in `input_x` are packed along a new dimension inserted at
    position `axis`, so the result has one more dimension than each input.

    Given `N` tensors of shape `(A, B, C)`:
    with `axis == 0` the output has shape `(N, A, B, C)`;
    with `axis == 1` the output has shape `(A, N, B, C)`; and so on.

    Args:
        axis (int): The axis to stack along. Negative values wrap around,
            so the valid range is [-(R+1), R+1). Default: 0.

    Inputs:
        - **input_x** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.

    Outputs:
        Tensor. A stacked Tensor with the same type as the elements of `input_x`.

    Examples:
        >>> data1 = Tensor(np.array([0, 1]).astype(np.float32))
        >>> data2 = Tensor(np.array([2, 3]).astype(np.float32))
        >>> op = P.Stack()
        >>> output = op([data1, data2])
        [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """init Stack"""
        self.__setattr_flag__ = True
        validator.check_type("axis", axis, [int])
        self.axis = axis

    def __infer__(self, value):
        shapes = value['shape']
        dtypes = value['dtype']
        # Record how many tensors are being stacked before the shape is inferred.
        self.add_prim_attr('num', len(shapes))
        stacked_shape = _get_stack_shape(shapes, dtypes, self.axis)
        return {'shape': stacked_shape,
                'dtype': dtypes[0],
                'value': None}
class Unstack(PrimitiveWithInfer):
    r"""
    Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.

    The input is chipped along the `axis` dimension. The number of outputs is
    always `input_x.shape[axis]`; it is inferred from the input shape and
    recorded as the primitive attribute `num` during inference.

    For example, given a tensor of shape (A, B, C, D):
    If axis == 0 then the i'th tensor in output is the slice value[i, :, :, :] and
    each tensor in output will have shape (B, C, D). (Note that the dimension unpacked along is gone, unlike split).
    If axis == 1 then the i'th tensor in output is the slice value[:, i, :, :] and
    each tensor in output will have shape (A, C, D). Etc.

    This is the opposite of stack.

    Args:
        axis (int): The axis to unstack along. Negative values wrap around,
            so the valid range is [-R, R). Default: 0.

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
          A rank R > 0 Tensor to be unstacked.

    Outputs:
        A tuple of Tensors, all with the same shape.

    Raises:
        ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).
            The size of the unstacked dimension must also be a positive int;
            that check is delegated to `validator`.

    Examples:
        >>> unstack = P.Unstack()
        >>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
        >>> output = unstack(x)
        ([1, 1, 1, 1], [2, 2, 2, 2])
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """init Unstack"""
        self.__setattr_flag__ = True
        validator.check_type("axis", axis, [int])
        self.axis = axis

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor)
        x_shape = list(x['shape'])
        dim = len(x_shape)
        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT)
        if self.axis < 0:
            # NOTE(review): this rewrites self.axis, so a negative axis is
            # resolved against the rank of the first input inferred; confirm
            # that persisting the normalised value is intended.
            self.axis = self.axis + dim
        output_num = x_shape[self.axis]
        validator.check_type("num", output_num, [int])
        validator.check_integer("output_num", output_num, 0, Rel.GT)
        self.add_prim_attr('num', output_num)
        # Every output drops the unstacked dimension and keeps the input dtype.
        out_shape = tuple(x_shape[:self.axis] + x_shape[self.axis + 1:])
        out = {'shape': (out_shape,) * output_num,
               'dtype': (x['dtype'],) * output_num,
               'value': None}
        return out
class Slice(PrimitiveWithInfer): class Slice(PrimitiveWithInfer):
""" """
Slice a tensor in specified shape. Slice a tensor in specified shape.

View File

@ -80,6 +80,29 @@ class NetForConcat1(nn.Cell):
return self.concat((x1, x2)) return self.concat((x1, x2))
class NetForStackInput(nn.Cell):
    """Test cell that squares each input with Mul and feeds the resulting tuple to `op`."""

    def __init__(self, op):
        super(NetForStackInput, self).__init__()
        self.mul = P.Mul()
        self.op = op

    def construct(self, *args):
        squared = ()
        for idx in range(len(args)):
            item = args[idx]
            squared = squared + (self.mul(item, item),)
        return self.op(squared)
class NetForUnstackInput(nn.Cell):
    """Test cell that squares its input with Mul and passes the product to `op`."""

    def __init__(self, op):
        super(NetForUnstackInput, self).__init__()
        self.mul = P.Mul()
        self.op = op

    def construct(self, x1):
        squared = self.mul(x1, x1)
        return self.op(squared)
class NetForFlatten(nn.Cell): class NetForFlatten(nn.Cell):
def __init__(self): def __init__(self):
super(NetForFlatten, self).__init__() super(NetForFlatten, self).__init__()
@ -968,6 +991,36 @@ test_case_array_ops = [
Tensor(np.array([1], np.float32)), Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)))], Tensor(np.array([1], np.float32)))],
'desc_bprop': [[3,]]}), 'desc_bprop': [[3,]]}),
('StackV2_0', {
'block': NetForStackInput(P.Stack()),
'desc_inputs':[[2, 2], [2, 2], [2, 2]],
'desc_bprop':[[3, 2, 2]],
}),
('StackV2_1', {
'block': NetForStackInput(P.Stack(axis=-2)),
'desc_inputs':[[3, 2, 3], [3, 2, 3], [3, 2, 3]],
'desc_bprop':[[3, 2, 3, 3]],
}),
('StackV2_2', {
'block': NetForStackInput(P.Stack()),
'desc_inputs':[[2, 2]],
'desc_bprop':[[2, 2, 2]],
}),
('StackV2_3', {
'block': NetForStackInput(P.Stack()),
'desc_inputs':[[128, 128], [128, 128]],
'desc_bprop':[[2, 128, 128]],
}),
('UnstackV2_0', {
'block': NetForUnstackInput(P.Unstack(axis=0)),
'desc_inputs':[[2, 4]],
'desc_bprop':[[4], [4]],
}),
('UnstackV2_1', {
'block': NetForUnstackInput(P.Unstack(axis=-1)),
'desc_inputs':[Tensor(np.array([[1, 1, 1]], np.float32))],
'desc_bprop':[[1], [1], [1]],
}),
('Diag', { ('Diag', {
'block': P.Diag(), 'block': P.Diag(),
'desc_inputs': [[4]], 'desc_inputs': [[4]],