!10239 add StackPush and StackPop for aicpu

From: @yanzhenxiang2020
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-15 16:53:11 +08:00 committed by Gitee
commit 58da4dbd23
5 changed files with 352 additions and 2 deletions

View File

@ -44,6 +44,10 @@ constexpr auto kSeed2 = "seed2";
constexpr auto kTopK = "TopK";
constexpr auto kTopKV2 = "TopKV2";
constexpr auto kStack = "Stack";
constexpr auto kStackInit = "StackInit";
constexpr auto kStackPush = "StackPush";
constexpr auto kStackPop = "StackPop";
constexpr auto kStackDestroy = "StackDestroy";
constexpr auto kEditDistance = "EditDistance";
constexpr auto kGatherD = "GatherD";
constexpr auto kIdentity = "Identity";
@ -55,8 +59,8 @@ constexpr auto kCustRunApi = "RunCpuKernel";
constexpr auto kDropout2D = "Dropout2D";
constexpr auto kDropout3D = "Dropout3D";
const std::set<std::string> kCustAiCpuKernelOps{kIdentity};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
kPadAndShift, kDropout3D, kDropout2D};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
kDropout2D, kStackInit, kStackPush, kStackPop, kStackDestroy};
struct AicpuParamHead {
uint32_t length; // Total length: include cunstom message

View File

@ -70,4 +70,8 @@ from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
from .fused_sparse_proximal_adagrad import _fused_sparse_proximal_adagrad_aicpu
from .meshgrid import _meshgrid_aicpu
from .trans_data import _trans_data_aicpu
from .stack_push_pop import _stack_init_aicpu
from .stack_push_pop import _stack_push_aicpu
from .stack_push_pop import _stack_pop_aicpu
from .stack_push_pop import _stack_destroy_aicpu
from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu

View File

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""StackPush and StackPop op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
stack_init_op_info = AiCPURegOp("StackInit") \
.fusion_type("OPAQUE") \
.attr("index", "int") \
.get_op_info()
stack_push_op_info = AiCPURegOp("StackPush") \
.fusion_type("OPAQUE") \
.input(0, "src", "required") \
.attr("index", "int") \
.dtype_format(DataType.U8_Default) \
.dtype_format(DataType.U16_Default) \
.dtype_format(DataType.U32_Default) \
.dtype_format(DataType.U64_Default) \
.dtype_format(DataType.I8_Default) \
.dtype_format(DataType.I16_Default) \
.dtype_format(DataType.I32_Default) \
.dtype_format(DataType.I64_Default) \
.dtype_format(DataType.F16_Default) \
.dtype_format(DataType.F32_Default) \
.dtype_format(DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default) \
.get_op_info()
stack_pop_op_info = AiCPURegOp("StackPop") \
.fusion_type("OPAQUE") \
.output(0, "dst", "required") \
.attr("index", "int") \
.dtype_format(DataType.U8_Default) \
.dtype_format(DataType.U16_Default) \
.dtype_format(DataType.U32_Default) \
.dtype_format(DataType.U64_Default) \
.dtype_format(DataType.I8_Default) \
.dtype_format(DataType.I16_Default) \
.dtype_format(DataType.I32_Default) \
.dtype_format(DataType.I64_Default) \
.dtype_format(DataType.F16_Default) \
.dtype_format(DataType.F32_Default) \
.dtype_format(DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default) \
.get_op_info()
stack_destroy_op_info = AiCPURegOp("StackDestroy") \
.fusion_type("OPAQUE") \
.attr("index", "int") \
.get_op_info()
@op_info_register(stack_init_op_info)
def _stack_init_aicpu():
"""StackInit aicpu register"""
return
@op_info_register(stack_push_op_info)
def _stack_push_aicpu():
"""StackPush aicpu register"""
return
@op_info_register(stack_pop_op_info)
def _stack_pop_aicpu():
"""StackPop aicpu register"""
return
@op_info_register(stack_destroy_op_info)
def _stack_destroy_aicpu():
"""StackDestroy aicpu register"""
return

View File

@ -15,6 +15,7 @@
"""Inner operators."""
import numpy as np
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
@ -882,3 +883,129 @@ class Centralization(PrimitiveWithInfer):
'dtype': x_dtype,
'value': None}
return out
class StackInit(PrimitiveWithInfer):
"""
Create a stack that produces tensors in first-in last-out order.
After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.
Args:
index (int): The index of the stack.
Supported Platforms:
``Ascend``
Examples:
>>> x = Tensor(np.array([[1, 3], [2, 0]]))
>>> index = 0
>>> stack = ops.StackInit(index)
>>> push = ops.StackPush(index)
>>> pop = ops.StackPop(index, x.shape, x.dtype)
>>> destroy = ops.StackDestroy(index)
>>> stack()
>>> push(x)
>>> y = pop()
>>> destroy()
>>> print(y)
[[1 3]
[2 0]]
"""
@prim_attr_register
def __init__(self, index=1):
"""StackInit"""
validator.check_value_type("index", index, [int], self.name)
class StackPush(PrimitiveWithInfer):
"""
Push a tensor onto the stack.
Before `StackPush`, the stack should be created using `StackInit`.
Please refer to the usage in source code of `StackInit`.
Args:
index (int): The index of the stack.
Inputs:
- **input** (Tensor) - A tensor to be pushed onto the stack.
Supported Platforms:
``Ascend``
Examples:
Please refer to the usage of `StackInit`.
"""
@prim_attr_register
def __init__(self, index=1):
"""StackPush"""
validator.check_value_type("index", index, [int], self.name)
self.init_prim_io_names(inputs=['input'], outputs=[])
class StackPop(PrimitiveWithInfer):
"""
Pop the tensor at the top of the stack.
Before `StackPop`, the stack should be created using `StackInit`.
Please refer to the usage in source code of `StackInit`.
Args:
index (int): The index of the stack.
shape (tuple): The shape of the tensor at the top of the stack.
dtype (mindspore.dtype): The type of the tensor at the top of the stack.
Outputs:
- **output** (Tensor) - The tensor at the top of the stack.
Supported Platforms:
``Ascend``
Examples:
Please refer to the usage of `StackInit`.
"""
@prim_attr_register
def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
"""StackPop"""
validator.check_value_type("index", index, [int], self.name)
validator.check_value_type('shape type', shape, [list, tuple], self.name)
validator.check_int(len(np.array(shape).shape), 1, Rel.EQ, "dim of shape", self.name)
for elem in shape:
validator.check_int(elem, 1, Rel.GE, 'shape element', self.name)
validator.check_value_type('type of shape element', elem, [int], self.name)
validator.check_type_name("dtype", dtype, (mstype.bool_,) + mstype.number_type, self.name)
self.shape = shape
self.dtype = dtype
self.init_prim_io_names(inputs=[], outputs=['output'])
def __infer__(self):
return {'shape': (list(self.shape)),
'dtype': (self.dtype),
'value': None}
class StackDestroy(PrimitiveWithInfer):
"""
Destroy the stack.
Before `StackDestroy`, the stack should be created using `StackInit`.
Please refer to the usage in source code of `StackInit`.
Args:
index (int): The index of the stack.
Supported Platforms:
``Ascend``
Examples:
Please refer to the usage of `StackInit`.
"""
@prim_attr_register
def __init__(self, index=1):
"""StackDestroy"""
validator.check_value_type("index", index, [int], self.name)

View File

@ -0,0 +1,128 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self, index=0, shapes_and_types=None):
super(Net, self).__init__()
shapes_and_types.reverse()
self.init = P.StackInit(index)
self.push = P.StackPush(index)
self.pop = [P.StackPop(index, shape, dtype) for (shape, dtype) in shapes_and_types]
self.destroy = P.StackDestroy(index)
def construct(self, x1, x2, x3):
self.init()
self.push(x1)
self.push(x2)
self.push(x3)
y1 = self.pop[0]()
y2 = self.pop[1]()
y3 = self.pop[2]()
self.destroy()
return y1, y2, y3
class NetTwoStack(nn.Cell):
def __init__(self, index=0, shapes_and_types=None):
super(NetTwoStack, self).__init__()
self.init_0 = P.StackInit(index)
self.push_0 = P.StackPush(index)
self.pop_0 = [P.StackPop(index, shape, dtype) for (shape, dtype) in shapes_and_types]
self.destroy_0 = P.StackDestroy(index)
index += 1
self.init_1 = P.StackInit(index)
self.push_1 = P.StackPush(index)
self.pop_1 = [P.StackPop(index, shape, dtype) for (shape, dtype) in shapes_and_types]
self.destroy_1 = P.StackDestroy(index)
def construct(self, x1, x2, x3):
self.init_0()
self.init_1()
self.push_0(x1)
self.push_1(x3)
y1 = self.pop_0[0]()
z1 = self.pop_1[2]()
self.push_0(x2)
self.push_0(x3)
self.push_1(x1)
self.push_1(x2)
y2 = self.pop_0[2]()
z2 = self.pop_1[1]()
y3 = self.pop_0[1]()
z3 = self.pop_1[0]()
self.destroy_0()
self.destroy_1()
return y1, y2, y3, z1, z2, z3
def test_net():
x1 = Tensor(np.random.randn(4,).astype(np.float64))
x2 = Tensor(np.random.randn(4, 6).astype(np.float32))
x3 = Tensor(np.random.randint(100, size=(3, 4, 5)).astype(np.int32))
shapes_and_types = []
shapes_and_types.append((x1.shape, x1.dtype))
shapes_and_types.append((x2.shape, x2.dtype))
shapes_and_types.append((x3.shape, x3.dtype))
net = Net(2018, shapes_and_types)
y1, y2, y3 = net(x1, x2, x3)
print(x1)
print(x2)
print(x3)
print(y1)
print(y2)
print(y3)
assert np.array_equal(y1.asnumpy(), x3.asnumpy())
assert np.array_equal(y2.asnumpy(), x2.asnumpy())
assert np.array_equal(y3.asnumpy(), x1.asnumpy())
def test_net_tow_stack():
x1 = Tensor(np.random.randn(4,).astype(np.float64))
x2 = Tensor(np.random.randn(4, 6).astype(np.float32))
x3 = Tensor(np.random.randint(100, size=(3, 4, 5)).astype(np.int32))
shapes_and_types = []
shapes_and_types.append((x1.shape, x1.dtype))
shapes_and_types.append((x2.shape, x2.dtype))
shapes_and_types.append((x3.shape, x3.dtype))
net = NetTwoStack(1998, shapes_and_types)
y1, y2, y3, z1, z2, z3 = net(x1, x2, x3)
print(x1)
print(x2)
print(x3)
print(y1)
print(y2)
print(y3)
assert np.array_equal(y1.asnumpy(), x1.asnumpy())
assert np.array_equal(y2.asnumpy(), x3.asnumpy())
assert np.array_equal(y3.asnumpy(), x2.asnumpy())
assert np.array_equal(z1.asnumpy(), x3.asnumpy())
assert np.array_equal(z2.asnumpy(), x2.asnumpy())
assert np.array_equal(z3.asnumpy(), x1.asnumpy())