forked from mindspore-Ecosystem/mindspore
!9085 Fix MultiFieldEmbedding Doc Error
From: @huangxinjing Reviewed-by: Signed-off-by:
This commit is contained in:
commit
3874160faf
|
@ -495,7 +495,7 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
|||
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(x->shape());
|
||||
|
|
|
@ -273,8 +273,8 @@ class EmbeddingLookup(Cell):
|
|||
|
||||
class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
||||
r"""
|
||||
Returns a slice of input tensor based on the specified indices based on the field ids. This operation
|
||||
supports looking up embeddings within multi hot and one hot fields simultaneously.
|
||||
Returns a slice of input tensor based on the specified indices and the field ids. This operation
|
||||
supports looking up embeddings using multi hot and one hot fields simultaneously.
|
||||
|
||||
Note:
|
||||
When 'target' is set to 'CPU', this module will use
|
||||
|
@ -282,13 +282,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
specified 'offset = 0' to lookup table.
|
||||
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
|
||||
specified 'axis = 0' to lookup table.
|
||||
The vectors with the same field_ids will be combined by the `operator`, such as `SUM`, `MAX` and
|
||||
`MEAN`. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
|
||||
The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and
|
||||
'MEAN'. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
|
||||
output will be zeros if the sum of absolute weight of the field is zero. This class only
|
||||
supports ['table_row_slice', 'batch_slice' and 'table_column_slice']
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
vocab_size (int): The size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
field_size (int): The field size of the final outputs.
|
||||
param_init (str): The initialize way of embedding table. Default: 'normal'.
|
||||
|
@ -296,7 +296,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
['DEVICE', 'CPU']. Default: 'CPU'.
|
||||
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
|
||||
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
|
||||
feature_num_list (tuple): The accompaniment array in field slice mode.
|
||||
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
|
||||
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
||||
or None. Default: None
|
||||
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
||||
|
@ -410,7 +410,6 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
|
||||
batch_size = self.shape(input_indices)[0]
|
||||
num_segments = batch_size * self.field_size
|
||||
|
||||
bias = Range(0, num_segments, self.field_size)()
|
||||
bias = self.reshape(bias, (self.field_size, -1))
|
||||
field_ids = self.bias_add(field_ids, bias)
|
||||
|
|
Loading…
Reference in New Issue