parent
5cba231ba9
commit
00e05f7c34
|
@ -44,10 +44,11 @@ class Embedding(Cell):
|
||||||
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input** (Tensor) - Tensor of shape :math:`(\text{vocab_size})`.
|
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of
|
||||||
|
the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero
|
||||||
|
if larger than vocab_size.
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor of shape :math:`(\text{vocab_size}, \text{embedding_size})`.
|
Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> net = nn.Embedding(20000, 768, True)
|
>>> net = nn.Embedding(20000, 768, True)
|
||||||
|
@ -61,6 +62,7 @@ class Embedding(Cell):
|
||||||
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
|
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
|
||||||
super(Embedding, self).__init__()
|
super(Embedding, self).__init__()
|
||||||
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
||||||
|
validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.embedding_size = embedding_size
|
self.embedding_size = embedding_size
|
||||||
self.use_one_hot = use_one_hot
|
self.use_one_hot = use_one_hot
|
||||||
|
|
|
@ -144,7 +144,7 @@ class Merge(PrimitiveWithInfer):
|
||||||
One and only one of the inputs should be selected as the output
|
One and only one of the inputs should be selected as the output
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **inputs** (Tuple) - The data to be merged. All tuple elements should have same data type.
|
- **inputs** (Union(Tuple, List)) - The data to be merged. All tuple elements should have same data type.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element.
|
tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element.
|
||||||
|
|
Loading…
Reference in New Issue