forked from mindspore-Ecosystem/mindspore
!10239 add StackPush and StackPop for aicpu
From: @yanzhenxiang2020 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
58da4dbd23
|
@ -44,6 +44,10 @@ constexpr auto kSeed2 = "seed2";
|
|||
constexpr auto kTopK = "TopK";
|
||||
constexpr auto kTopKV2 = "TopKV2";
|
||||
constexpr auto kStack = "Stack";
|
||||
constexpr auto kStackInit = "StackInit";
|
||||
constexpr auto kStackPush = "StackPush";
|
||||
constexpr auto kStackPop = "StackPop";
|
||||
constexpr auto kStackDestroy = "StackDestroy";
|
||||
constexpr auto kEditDistance = "EditDistance";
|
||||
constexpr auto kGatherD = "GatherD";
|
||||
constexpr auto kIdentity = "Identity";
|
||||
|
@ -55,8 +59,8 @@ constexpr auto kCustRunApi = "RunCpuKernel";
|
|||
constexpr auto kDropout2D = "Dropout2D";
|
||||
constexpr auto kDropout3D = "Dropout3D";
|
||||
const std::set<std::string> kCustAiCpuKernelOps{kIdentity};
|
||||
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
|
||||
kPadAndShift, kDropout3D, kDropout2D};
|
||||
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
|
||||
kDropout2D, kStackInit, kStackPush, kStackPop, kStackDestroy};
|
||||
|
||||
struct AicpuParamHead {
|
||||
uint32_t length; // Total length: include cunstom message
|
||||
|
|
|
@ -70,4 +70,8 @@ from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
|
|||
from .fused_sparse_proximal_adagrad import _fused_sparse_proximal_adagrad_aicpu
|
||||
from .meshgrid import _meshgrid_aicpu
|
||||
from .trans_data import _trans_data_aicpu
|
||||
from .stack_push_pop import _stack_init_aicpu
|
||||
from .stack_push_pop import _stack_push_aicpu
|
||||
from .stack_push_pop import _stack_pop_aicpu
|
||||
from .stack_push_pop import _stack_destroy_aicpu
|
||||
from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""StackPush and StackPop op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
stack_init_op_info = AiCPURegOp("StackInit") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("index", "int") \
|
||||
.get_op_info()
|
||||
|
||||
stack_push_op_info = AiCPURegOp("StackPush") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "src", "required") \
|
||||
.attr("index", "int") \
|
||||
.dtype_format(DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default) \
|
||||
.dtype_format(DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
stack_pop_op_info = AiCPURegOp("StackPop") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.output(0, "dst", "required") \
|
||||
.attr("index", "int") \
|
||||
.dtype_format(DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default) \
|
||||
.dtype_format(DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
stack_destroy_op_info = AiCPURegOp("StackDestroy") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("index", "int") \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(stack_init_op_info)
|
||||
def _stack_init_aicpu():
|
||||
"""StackInit aicpu register"""
|
||||
return
|
||||
|
||||
|
||||
@op_info_register(stack_push_op_info)
|
||||
def _stack_push_aicpu():
|
||||
"""StackPush aicpu register"""
|
||||
return
|
||||
|
||||
|
||||
@op_info_register(stack_pop_op_info)
|
||||
def _stack_pop_aicpu():
|
||||
"""StackPop aicpu register"""
|
||||
return
|
||||
|
||||
|
||||
@op_info_register(stack_destroy_op_info)
|
||||
def _stack_destroy_aicpu():
|
||||
"""StackDestroy aicpu register"""
|
||||
return
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Inner operators."""
|
||||
|
||||
import numpy as np
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ... import context
|
||||
|
@ -882,3 +883,129 @@ class Centralization(PrimitiveWithInfer):
|
|||
'dtype': x_dtype,
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class StackInit(PrimitiveWithInfer):
|
||||
"""
|
||||
Create a stack that produces tensors in first-in last-out order.
|
||||
|
||||
After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
|
||||
at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.
|
||||
|
||||
Args:
|
||||
index (int): The index of the stack.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[1, 3], [2, 0]]))
|
||||
>>> index = 0
|
||||
>>> stack = ops.StackInit(index)
|
||||
>>> push = ops.StackPush(index)
|
||||
>>> pop = ops.StackPop(index, x.shape, x.dtype)
|
||||
>>> destroy = ops.StackDestroy(index)
|
||||
>>> stack()
|
||||
>>> push(x)
|
||||
>>> y = pop()
|
||||
>>> destroy()
|
||||
>>> print(y)
|
||||
[[1 3]
|
||||
[2 0]]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, index=1):
|
||||
"""StackInit"""
|
||||
validator.check_value_type("index", index, [int], self.name)
|
||||
|
||||
|
||||
class StackPush(PrimitiveWithInfer):
|
||||
"""
|
||||
Push a tensor onto the stack.
|
||||
|
||||
Before `StackPush`, the stack should be created using `StackInit`.
|
||||
Please refer to the usage in source code of `StackInit`.
|
||||
|
||||
Args:
|
||||
index (int): The index of the stack.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - A tensor to be pushed onto the stack.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
Please refer to the usage of `StackInit`.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, index=1):
|
||||
"""StackPush"""
|
||||
validator.check_value_type("index", index, [int], self.name)
|
||||
self.init_prim_io_names(inputs=['input'], outputs=[])
|
||||
|
||||
|
||||
class StackPop(PrimitiveWithInfer):
|
||||
"""
|
||||
Pop the tensor at the top of the stack.
|
||||
|
||||
Before `StackPop`, the stack should be created using `StackInit`.
|
||||
Please refer to the usage in source code of `StackInit`.
|
||||
|
||||
Args:
|
||||
index (int): The index of the stack.
|
||||
shape (tuple): The shape of the tensor at the top of the stack.
|
||||
dtype (mindspore.dtype): The type of the tensor at the top of the stack.
|
||||
|
||||
Outputs:
|
||||
- **output** (Tensor) - The tensor at the top of the stack.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
Please refer to the usage of `StackInit`.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
|
||||
"""StackPop"""
|
||||
validator.check_value_type("index", index, [int], self.name)
|
||||
|
||||
validator.check_value_type('shape type', shape, [list, tuple], self.name)
|
||||
validator.check_int(len(np.array(shape).shape), 1, Rel.EQ, "dim of shape", self.name)
|
||||
for elem in shape:
|
||||
validator.check_int(elem, 1, Rel.GE, 'shape element', self.name)
|
||||
validator.check_value_type('type of shape element', elem, [int], self.name)
|
||||
|
||||
validator.check_type_name("dtype", dtype, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
self.init_prim_io_names(inputs=[], outputs=['output'])
|
||||
|
||||
def __infer__(self):
|
||||
return {'shape': (list(self.shape)),
|
||||
'dtype': (self.dtype),
|
||||
'value': None}
|
||||
|
||||
|
||||
class StackDestroy(PrimitiveWithInfer):
|
||||
"""
|
||||
Destroy the stack.
|
||||
|
||||
Before `StackDestroy`, the stack should be created using `StackInit`.
|
||||
Please refer to the usage in source code of `StackInit`.
|
||||
|
||||
Args:
|
||||
index (int): The index of the stack.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
Please refer to the usage of `StackInit`.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, index=1):
|
||||
"""StackDestroy"""
|
||||
validator.check_value_type("index", index, [int], self.name)
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _inner_ops as P
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, index=0, shapes_and_types=None):
|
||||
super(Net, self).__init__()
|
||||
shapes_and_types.reverse()
|
||||
self.init = P.StackInit(index)
|
||||
self.push = P.StackPush(index)
|
||||
self.pop = [P.StackPop(index, shape, dtype) for (shape, dtype) in shapes_and_types]
|
||||
self.destroy = P.StackDestroy(index)
|
||||
|
||||
def construct(self, x1, x2, x3):
|
||||
self.init()
|
||||
self.push(x1)
|
||||
self.push(x2)
|
||||
self.push(x3)
|
||||
y1 = self.pop[0]()
|
||||
y2 = self.pop[1]()
|
||||
y3 = self.pop[2]()
|
||||
self.destroy()
|
||||
return y1, y2, y3
|
||||
|
||||
|
||||
class NetTwoStack(nn.Cell):
|
||||
def __init__(self, index=0, shapes_and_types=None):
|
||||
super(NetTwoStack, self).__init__()
|
||||
self.init_0 = P.StackInit(index)
|
||||
self.push_0 = P.StackPush(index)
|
||||
self.pop_0 = [P.StackPop(index, shape, dtype) for (shape, dtype) in shapes_and_types]
|
||||
self.destroy_0 = P.StackDestroy(index)
|
||||
|
||||
index += 1
|
||||
self.init_1 = P.StackInit(index)
|
||||
self.push_1 = P.StackPush(index)
|
||||
self.pop_1 = [P.StackPop(index, shape, dtype) for (shape, dtype) in shapes_and_types]
|
||||
self.destroy_1 = P.StackDestroy(index)
|
||||
|
||||
def construct(self, x1, x2, x3):
|
||||
self.init_0()
|
||||
self.init_1()
|
||||
|
||||
self.push_0(x1)
|
||||
self.push_1(x3)
|
||||
y1 = self.pop_0[0]()
|
||||
z1 = self.pop_1[2]()
|
||||
self.push_0(x2)
|
||||
self.push_0(x3)
|
||||
self.push_1(x1)
|
||||
self.push_1(x2)
|
||||
y2 = self.pop_0[2]()
|
||||
z2 = self.pop_1[1]()
|
||||
y3 = self.pop_0[1]()
|
||||
z3 = self.pop_1[0]()
|
||||
|
||||
self.destroy_0()
|
||||
self.destroy_1()
|
||||
return y1, y2, y3, z1, z2, z3
|
||||
|
||||
|
||||
def test_net():
|
||||
x1 = Tensor(np.random.randn(4,).astype(np.float64))
|
||||
x2 = Tensor(np.random.randn(4, 6).astype(np.float32))
|
||||
x3 = Tensor(np.random.randint(100, size=(3, 4, 5)).astype(np.int32))
|
||||
|
||||
shapes_and_types = []
|
||||
shapes_and_types.append((x1.shape, x1.dtype))
|
||||
shapes_and_types.append((x2.shape, x2.dtype))
|
||||
shapes_and_types.append((x3.shape, x3.dtype))
|
||||
|
||||
net = Net(2018, shapes_and_types)
|
||||
y1, y2, y3 = net(x1, x2, x3)
|
||||
print(x1)
|
||||
print(x2)
|
||||
print(x3)
|
||||
print(y1)
|
||||
print(y2)
|
||||
print(y3)
|
||||
assert np.array_equal(y1.asnumpy(), x3.asnumpy())
|
||||
assert np.array_equal(y2.asnumpy(), x2.asnumpy())
|
||||
assert np.array_equal(y3.asnumpy(), x1.asnumpy())
|
||||
|
||||
|
||||
def test_net_tow_stack():
|
||||
x1 = Tensor(np.random.randn(4,).astype(np.float64))
|
||||
x2 = Tensor(np.random.randn(4, 6).astype(np.float32))
|
||||
x3 = Tensor(np.random.randint(100, size=(3, 4, 5)).astype(np.int32))
|
||||
|
||||
shapes_and_types = []
|
||||
shapes_and_types.append((x1.shape, x1.dtype))
|
||||
shapes_and_types.append((x2.shape, x2.dtype))
|
||||
shapes_and_types.append((x3.shape, x3.dtype))
|
||||
|
||||
net = NetTwoStack(1998, shapes_and_types)
|
||||
y1, y2, y3, z1, z2, z3 = net(x1, x2, x3)
|
||||
print(x1)
|
||||
print(x2)
|
||||
print(x3)
|
||||
print(y1)
|
||||
print(y2)
|
||||
print(y3)
|
||||
assert np.array_equal(y1.asnumpy(), x1.asnumpy())
|
||||
assert np.array_equal(y2.asnumpy(), x3.asnumpy())
|
||||
assert np.array_equal(y3.asnumpy(), x2.asnumpy())
|
||||
|
||||
assert np.array_equal(z1.asnumpy(), x3.asnumpy())
|
||||
assert np.array_equal(z2.asnumpy(), x2.asnumpy())
|
||||
assert np.array_equal(z3.asnumpy(), x1.asnumpy())
|
Loading…
Reference in New Issue