!9085 Fix MultiFieldEmbedding Doc Error

From: @huangxinjing
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-27 19:02:30 +08:00 committed by Gitee
commit 3874160faf
2 changed files with 7 additions and 8 deletions

View File

@ -495,7 +495,7 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
CheckArgsSize(op_name, args_spec_list, 1);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());

View File

@ -273,8 +273,8 @@ class EmbeddingLookup(Cell):
class MultiFieldEmbeddingLookup(EmbeddingLookup):
r"""
Returns a slice of input tensor based on the specified indices based on the field ids. This operation
supports looking up embeddings within multi hot and one hot fields simultaneously.
Returns a slice of input tensor based on the specified indices and the field ids. This operation
supports looking up embeddings using multi hot and one hot fields simultaneously.
Note:
When 'target' is set to 'CPU', this module will use
@ -282,13 +282,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
specified 'offset = 0' to lookup table.
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
specified 'axis = 0' to lookup table.
The vectors with the same field_ids will be combined by the `operator`, such as `SUM`, `MAX` and
`MEAN`. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and
'MEAN'. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
output will be zeros if the sum of absolute weight of the field is zero. This class only
supports ['table_row_slice', 'batch_slice' and 'table_column_slice']
Args:
vocab_size (int): Size of the dictionary of embeddings.
vocab_size (int): The size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
field_size (int): The field size of the final outputs.
param_init (str): The initialize way of embedding table. Default: 'normal'.
@ -296,7 +296,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
feature_num_list (tuple): The accompaniment array in field slice mode.
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
@ -410,7 +410,6 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
batch_size = self.shape(input_indices)[0]
num_segments = batch_size * self.field_size
bias = Range(0, num_segments, self.field_size)()
bias = self.reshape(bias, (self.field_size, -1))
field_ids = self.bias_add(field_ids, bias)