forked from mindspore-Ecosystem/mindspore
!9058 Fix MultiFieldEmbedding Doc Error
From: @huangxinjing Reviewed-by: @stsuteng,@zh_qh Signed-off-by: @stsuteng
This commit is contained in:
commit
5ae37ce350
|
@ -273,7 +273,7 @@ class EmbeddingLookup(Cell):
|
|||
|
||||
class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
||||
r"""
|
||||
Returns a slice of input tensor based on the specified indices based on the filed ids. This operation
|
||||
Returns a slice of input tensor based on the specified indices based on the field ids. This operation
|
||||
supports looking up embeddings within multi hot and one hot fields simultaneously.
|
||||
|
||||
Note:
|
||||
|
@ -284,7 +284,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
specified 'axis = 0' to lookup table.
|
||||
The vectors with the same field_ids will be combined by the `operator`, such as `SUM`, `MAX` and
|
||||
`MEAN`. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
|
||||
output will be zeros if the summed of absolute weight of the field is zero. This class only
|
||||
output will be zeros if the sum of absolute weight of the field is zero. This class only
|
||||
supports ['table_row_slice', 'batch_slice' and 'table_column_slice']
|
||||
|
||||
Args:
|
||||
|
@ -300,29 +300,31 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
||||
or None. Default: None
|
||||
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
||||
operator (string): The pooling method for the features in one field. Support `SUM`, `MEAN` and 'MAX'
|
||||
operator (string): The pooling method for the features in one field. Support 'SUM, 'MEAN' and 'MAX'
|
||||
|
||||
Inputs:
|
||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
|
||||
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
|
||||
and the exceeding part will be filled with 0 in the output. Input_indices must be a 2d tensor in
|
||||
Specifies the indices of elements of the original Tensor. Input_indices must be a 2d tensor in
|
||||
this interface. Type is Int16, Int32, Int64.
|
||||
- **input_values** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
|
||||
Specifies the weights of elements of the input_indices. The lookout vector will multiply with
|
||||
the input_values. Type is Float32.
|
||||
- **field_ids** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
|
||||
Specifics the field id of elements of the input_indices. Type is Type is Int16, Int32, Int64.
|
||||
Specifies the field id of elements of the input_indices. Type is Type is Int16, Int32.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of tensor is :math:`(batch_size, field_size, embedding_size)`. Type is Float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
|
||||
>>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
|
||||
>>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
|
||||
>>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM')
|
||||
>>> out = net(input_indices, input_values, field_ids)
|
||||
>>> print(result)
|
||||
>>> print(out)
|
||||
[[[-0.00478983 -0.00772568]
|
||||
[-0.00968955 -0.00064902]]
|
||||
[[-0.01251151 -0.01251151]
|
||||
|
@ -335,7 +337,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
|
||||
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
|
||||
slice_mode, feature_num_list, max_norm, sparse)
|
||||
self.field_size = field_size
|
||||
self.field_size = validator.check_value_type('field_size', field_size, [int], self.cls_name)
|
||||
self.operator = operator
|
||||
|
||||
self.mul = P.Mul()
|
||||
|
|
Loading…
Reference in New Issue