forked from mindspore-Ecosystem/mindspore
amend randperm and add NoRepeatNGram ops for aicpu
parent 80fe11ac7b
commit d4bca2b9c3
@@ -14,6 +14,7 @@
"""aicpu ops"""
from .unique import _unique_aicpu
+from .no_repeat_ngram import _no_repeat_ngram_aicpu
from .init_data_set_queue import _init_data_set_queue_aicpu
from .embedding_lookup import _embedding_lookup_aicpu
from .padding import _padding_aicpu
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""NoRepeatNGram op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+no_repeat_ngram_op_info = AiCPURegOp("NoRepeatNGram") \
+    .fusion_type("OPAQUE") \
+    .input(0, "state_seq", "required") \
+    .input(1, "log_probs", "required") \
+    .output(0, "out", "required") \
+    .attr("ngram_size", "int") \
+    .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(no_repeat_ngram_op_info)
+def _no_repeat_ngram_aicpu():
+    """NoRepeatNGram AiCPU register"""
+    return
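
Note: in each dtype_format row above, the dtypes are positional, one per declared input and output — here (state_seq, log_probs, out) — so the kernel is registered for int32 token sequences with float16, float32, or float64 log-probabilities.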

@@ -18,16 +18,18 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

randperm_op_info = AiCPURegOp("Randperm") \
    .fusion_type("OPAQUE") \
+    .input(0, "n", "required") \
    .output(0, "y", "required") \
-    .attr("n", "int") \
-    .dtype_format(DataType.I8_Default) \
-    .dtype_format(DataType.I16_Default) \
-    .dtype_format(DataType.I32_Default) \
-    .dtype_format(DataType.I64_Default) \
-    .dtype_format(DataType.U8_Default) \
-    .dtype_format(DataType.U16_Default) \
-    .dtype_format(DataType.U32_Default) \
-    .dtype_format(DataType.U64_Default) \
+    .attr("max_length", "int") \
+    .attr("pad", "int") \
+    .dtype_format(DataType.I32_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U64_Default) \
    .get_op_info()

@op_info_register(randperm_op_info)
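
Note: the amended registration replaces the compile-time attribute n with a runtime tensor input plus max_length/pad attributes, so the output shape stays static at max_length while the permutation length n can vary. A rough NumPy sketch of the intended semantics (the helper name is ours, not part of the commit):

import numpy as np

def randperm_padded(n, max_length, pad=-1, dtype=np.int32):
    """First n entries: a random permutation of [0, n); the rest is filled with pad."""
    assert 0 < n <= max_length
    out = np.full(max_length, pad, dtype=dtype)
    out[:n] = np.random.permutation(n).astype(dtype)
    return out

# e.g. randperm_padded(3, 5) -> something like [2 0 1 -1 -1]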

@@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                        TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import ControlDepend, GeSwitch, Merge
-from .inner_ops import ScalarCast
+from .inner_ops import ScalarCast, Randperm, NoRepeatNGram

from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
                       BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,

@@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub,
                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)

-from .random_ops import (Randperm, RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
+from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
                         RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
                         LogUniformCandidateSampler)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,

@@ -200,6 +200,7 @@ __all__ = [
    'HSwish',
    'HSigmoid',
    'Tanh',
+    'NoRepeatNGram',
    'Randperm',
    'RandomChoiceWithMask',
    'StandardNormal',
@@ -17,6 +17,8 @@
import numbers
from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer

@@ -59,3 +61,114 @@ class ScalarCast(PrimitiveWithInfer):
               'dtype': t['value'],
               'value': value}
        return out
+
+
+class Randperm(PrimitiveWithInfer):
+    """
+    Generates n random samples from 0 to n-1 without repeating. If `max_length` > n,
+    the last `max_length` - n elements are filled with the value of `pad`.
+
+    Args:
+        max_length (int): Number of items expected to get; must be greater than 0. Default: 1.
+        pad (int): The pad value to be filled. Default: -1.
+        dtype (mindspore.dtype): The type of output. Default: mindspore.int32.
+
+    Inputs:
+        - **n** (Tensor[int]) - The input tensor with shape: (1,); the value of n must be
+          in (0, `max_length`].
+
+    Outputs:
+        - **output** (Tensor) - The output Tensor with shape: (`max_length`,) and type: `dtype`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> randperm = ops.Randperm(max_length=30, pad=-1)
+        >>> n = Tensor([20], dtype=mindspore.int32)
+        >>> output = randperm(n)
+        >>> print(output)
+        [15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 1 12 3 7
+         -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
+    """
+
+    @prim_attr_register
+    def __init__(self, max_length=1, pad=-1, dtype=mstype.int32):
+        """Initialize Randperm"""
+        validator.check_value_type("pad", pad, [int], self.name)
+        validator.check_value_type("max_length", max_length, [int], self.name)
+        validator.check_int(max_length, 1, Rel.GE, "max_length", self.name)
+        self.dtype = dtype
+        self.max_length = max_length
+        self.init_prim_io_names(inputs=[], outputs=['output'])
+
+    def infer_shape(self, n_shape):
+        validator.check_int(len(n_shape), 1, Rel.EQ, "rank_of_n", self.name)
+        validator.check_int(n_shape[0], 1, Rel.EQ, "length_of_n", self.name)
+        return [self.max_length]
+
+    def infer_dtype(self, n_type):
+        validator.check_type_name("n_type", n_type, mstype.int32, self.name)
+
+        valid_values = (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
+                        mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64)
+        validator.check_type_name("dtype", self.dtype, valid_values, self.name)
+        return self.dtype
+
+
+class NoRepeatNGram(PrimitiveWithInfer):
+    """
+    Updates log_probs by masking tokens that would repeat an n-gram already
+    present in state_seq.
+
+    Args:
+        ngram_size (int): Size of n-grams; must be greater than 0. Default: 1.
+
+    Inputs:
+        - **state_seq** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, m).
+        - **log_probs** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, vocab_size).
+          The value of log_probs is replaced with -FLOAT_MAX for every token that would
+          complete a repeated n-gram.
+
+    Outputs:
+        - **log_probs** (Tensor) - The output Tensor with the same shape and type as the
+          input `log_probs`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
+        >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2]],
+                                [[4, 8, 6, 4, 5, 6, 4]]], dtype=mindspore.int32)
+        >>> log_probs = Tensor([[[0.75858542, 0.8437121, 0.69025469, 0.79379992, 0.27400691,
+                                  0.84709179, 0.78771346, 0.68587179, 0.22943851, 0.17682976]],
+                                [[0.99401879, 0.77239773, 0.81973878, 0.32085208, 0.59944118,
+                                  0.3125177, 0.52604189, 0.77111461, 0.98443699, 0.71532898]]], dtype=mindspore.float32)
+        >>> output = no_repeat_ngram(state_seq, log_probs)
+        >>> print(output)
+        [[[0.75858542 -3.4028235e+38 0.69025469 0.79379992 0.27400691
+           -3.4028235e+38 0.78771346 0.68587179 0.22943851 0.17682976]]
+         [[0.99401879 0.77239773 0.81973878 0.32085208 0.59944118
+           -3.4028235e+38 0.52604189 0.77111461 0.98443699 0.71532898]]]
+    """
+
+    @prim_attr_register
+    def __init__(self, ngram_size=1):
+        """Initialize NoRepeatNGram"""
+        validator.check_value_type("ngram_size", ngram_size, [int], self.name)
+        validator.check_int(ngram_size, 1, Rel.GE, "ngram_size", self.name)
+        self.ngram_size = ngram_size
+        self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs'])
+
+    def infer_shape(self, seq_shape, log_shape):
+        validator.check_int(len(seq_shape), 3, Rel.EQ, "rank_of_seq", self.name)
+        validator.check_int(len(log_shape), 3, Rel.EQ, "rank_of_log", self.name)
+        validator.check_int(seq_shape[0], log_shape[0], Rel.EQ, "seq_shape shape[0]", self.name)
+        validator.check_int(seq_shape[1], log_shape[1], Rel.EQ, "seq_shape shape[1]", self.name)
+        validator.check_int(self.ngram_size, seq_shape[2] + 1, Rel.LE, "ngram_size", self.name)
+        return log_shape
+
+    def infer_dtype(self, seq_type, log_type):
+        validator.check_type_name("seq_type", seq_type, mstype.int32, self.name)
+        valid_values = (mstype.float16, mstype.float32, mstype.float64)
+        validator.check_type_name("log_type", log_type, valid_values, self.name)
+        return log_type
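
Note: as a reference for the masking rule the NoRepeatNGram docstring describes, here is a rough NumPy sketch (the helper name and FLT_MAX constant are ours, not part of the commit). For each (batch, beam) row, any token that would complete an n-gram already contained in state_seq has its log-probability replaced with -FLOAT_MAX:

import numpy as np

FLT_MAX = 3.4028235e+38

def no_repeat_ngram_ref(state_seq, log_probs, ngram_size):
    """state_seq: (batch, beam, m) int array; log_probs: (batch, beam, vocab) float array."""
    out = log_probs.copy()
    batch, beam, m = state_seq.shape
    for b in range(batch):
        for k in range(beam):
            seq = [int(t) for t in state_seq[b, k]]
            # The (ngram_size - 1) most recent tokens in this beam.
            prefix = tuple(seq[m - (ngram_size - 1):])
            for i in range(m - ngram_size + 1):
                # If an earlier n-gram starts with the same prefix, ban its final token.
                if tuple(seq[i:i + ngram_size - 1]) == prefix:
                    out[b, k, seq[i + ngram_size - 1]] = -FLT_MAX
    return out

For state_seq row [1, 2, 1, 2, 5, 1, 2] with ngram_size=3, the trailing bigram (1, 2) earlier precedes tokens 1 and 5, so vocab indices 1 and 5 are masked, matching the docstring example above.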

@@ -345,46 +345,6 @@ class UniformReal(PrimitiveWithInfer):
        return out


-class Randperm(PrimitiveWithInfer):
-    """
-    Generates random samples from 0 to n-1.
-
-    Args:
-        n (int): Number of items expected to get and the number must be greater than 0. Default: 1.
-        dtype (mindspore.dtype): The type of output. Default: mindspore.int32.
-
-    Outputs:
-        - **output** (Tensor) - The output Tensor with shape :math:`(n,)` and type: dtype.
-
-    Supported Platforms:
-        ``Ascend``
-
-    Examples:
-        >>> randperm = ops.Randperm(20)
-        >>> output = randperm()
-        >>> print(output)
-        [15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 14 1 12 3 7]
-    """
-
-    @prim_attr_register
-    def __init__(self, n=1, dtype=mstype.int32):
-        """Initialize Randperm"""
-        Validator.check_value_type("n", n, [int], self.name)
-        self.dtype = dtype
-        self.n = n
-        self.init_prim_io_names(inputs=[], outputs=['output'])
-
-    def infer_shape(self):
-        Validator.check_int(self.n, 1, Rel.GE, "1", self.name)
-        return [self.n]
-
-    def infer_dtype(self):
-        valid_values = (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
-                        mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64)
-        Validator.check_type_name("dtype", self.dtype, valid_values, self.name)
-        return self.dtype
-
-
class RandomChoiceWithMask(PrimitiveWithInfer):
    """
    Generates a random sample as index tensor with a mask tensor from a given tensor.
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+FLT_MAX = 3.4028235e+38
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, ngram_size):
+        super(Net, self).__init__()
+        self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)
+
+    def construct(self, state_seq, log_probs):
+        return self.no_repeat_ngram(state_seq, log_probs)
+
+
+def test_net():
+    state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
+                         [9, 3, 9, 5, 4, 1, 5],
+                         [4, 7, 9, 1, 9, 6, 1],
+                         [7, 6, 4, 2, 9, 1, 5],
+                         [7, 5, 8, 9, 9, 3, 9]],
+                        [[7, 7, 2, 7, 9, 9, 4],
+                         [3, 4, 7, 4, 7, 6, 8],
+                         [1, 9, 5, 7, 6, 9, 3],
+                         [4, 8, 6, 4, 5, 6, 4],
+                         [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
+
+    log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
+    expect_log_probs = log_probs.asnumpy().copy()
+    expect_log_probs[0, 0, 1] = -FLT_MAX
+    expect_log_probs[0, 0, 5] = -FLT_MAX
+    expect_log_probs[1, 3, 5] = -FLT_MAX
+    expect_log_probs[1, 4, 8] = -FLT_MAX
+
+    net = Net(ngram_size=3)
+    output = net(state_seq, log_probs)
+
+    print(expect_log_probs)
+    print(output)
+    assert np.array_equal(expect_log_probs, output.asnumpy())
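
Note: the four masked positions in expect_log_probs follow from the banning rule: with ngram_size=3, the row [1, 2, 1, 2, 5, 1, 2] ends in the bigram (1, 2), which earlier precedes tokens 1 and 5, so vocab indices 1 and 5 are masked; [4, 8, 6, 4, 5, 6, 4] ends in (6, 4), which precedes 5; and [4, 8, 8, 4, 3, 4, 8] ends in (4, 8), which precedes 8. With the hypothetical no_repeat_ngram_ref sketch from the inner_ops hunk, the same expectation could be computed as:

ref = no_repeat_ngram_ref(state_seq.asnumpy(), log_probs.asnumpy(), ngram_size=3)
assert np.array_equal(ref, expect_log_probs)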
@@ -15,23 +15,24 @@
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
-    def __init__(self, n=1, dtype=mindspore.int32):
+    def __init__(self, max_length, pad, dtype=mindspore.int32):
        super(Net, self).__init__()
-        self.randperm = P.Randperm(n, dtype)
+        self.randperm = P.Randperm(max_length, pad, dtype)

-    def construct(self):
-        return self.randperm()
+    def construct(self, n):
+        return self.randperm(n)


def test_net():
-    net = Net()
-    output = net()
+    net = Net(max_length=1, pad=-1)
+    output = net(Tensor([1], mindspore.int32))

    print(output)
    print(output.shape)

@@ -42,15 +43,18 @@ def test_net():


def test_net_n20():
-    net = Net(20, mindspore.uint64)
-    output = net()
+    net = Net(max_length=30, pad=-1, dtype=mindspore.int32)
+    output = net(Tensor([20], dtype=mindspore.int32))

    print(output)
-    assert output.shape == (20,)
-    assert output.dtype == mindspore.uint64
+    assert output.shape == (30,)
+    assert output.dtype == mindspore.int32

    sample_set = set()
-    for i in output.asnumpy():
-        assert i not in sample_set
-        assert 0 <= i < 20
-        sample_set.add(i)
+    for index, i in enumerate(output.asnumpy()):
+        if index < 20:
+            assert i not in sample_set
+            assert 0 <= i < 20
+            sample_set.add(i)
+        else:
+            assert i == -1