From 6aea4ba4a718a9ccc68d6f472625a27a1231be26 Mon Sep 17 00:00:00 2001 From: huangxinjing Date: Thu, 26 Nov 2020 21:20:52 +0800 Subject: [PATCH] Fix doc for multi-field embedding --- mindspore/core/abstract/prim_others.cc | 2 +- mindspore/nn/layer/embedding.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 0a4f7b1fa23..46a30b3e6c3 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -495,7 +495,7 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); + CheckArgsSize(op_name, args_spec_list, 1); auto x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x->shape()); diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index f6e8c8f48c7..df338b8ecc4 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -273,8 +273,8 @@ class EmbeddingLookup(Cell): class MultiFieldEmbeddingLookup(EmbeddingLookup): r""" - Returns a slice of input tensor based on the specified indices based on the field ids. This operation - supports looking up embeddings within multi hot and one hot fields simultaneously. + Returns a slice of input tensor based on the specified indices and the field ids. This operation + supports looking up embeddings using multi hot and one hot fields simultaneously. Note: When 'target' is set to 'CPU', this module will use @@ -282,13 +282,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup): specified 'offset = 0' to lookup table. When 'target' is set to 'DEVICE', this module will use P.GatherV2() which specified 'axis = 0' to lookup table. - The vectors with the same field_ids will be combined by the `operator`, such as `SUM`, `MAX` and - `MEAN`. Ensure the input_values of the padded id is zero, so that they can be ignored. The final + The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and + 'MEAN'. Ensure the input_values of the padded id is zero, so that they can be ignored. The final output will be zeros if the sum of absolute weight of the field is zero. This class only supports ['table_row_slice', 'batch_slice' and 'table_column_slice'] Args: - vocab_size (int): Size of the dictionary of embeddings. + vocab_size (int): The size of the dictionary of embeddings. embedding_size (int): The size of each embedding vector. field_size (int): The field size of the final outputs. param_init (str): The initialize way of embedding table. Default: 'normal'. @@ -296,7 +296,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup): ['DEVICE', 'CPU']. Default: 'CPU'. slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE. - feature_num_list (tuple): The accompaniment array in field slice mode. + feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently. max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None. Default: None sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. @@ -410,7 +410,6 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup): batch_size = self.shape(input_indices)[0] num_segments = batch_size * self.field_size - bias = Range(0, num_segments, self.field_size)() bias = self.reshape(bias, (self.field_size, -1)) field_ids = self.bias_add(field_ids, bias)