!25853 remove cache for embeddinglookup layer

Merge pull request !25853 from fangzehua/remove_cache_embedding
i-robot 2021-11-05 01:33:13 +00:00 committed by Gitee
commit ea5ff98bf7
1 changed file with 3 additions and 25 deletions


@@ -181,8 +181,8 @@ class EmbeddingLookup(Cell):
             or None. Default: None
         sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
         vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
-            'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
-            In addition, it should be noted that it will cost the 'DEVICE'
+            parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
+            optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
             memory, so suggests setting a reasonable value to avoid insufficient memory.
 
     Inputs:
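
After this change, 'vocab_cache_size' only takes effect together with parameter server training. The sketch below shows roughly how the layer is meant to be configured once the patch is applied; the sizes are hypothetical, the full PS launch (scheduler/server/worker roles and their environment variables) is assumed to be set up separately, and context.set_ps_context / Cell.set_param_ps are used here as the usual MindSpore PS entry points rather than anything taken from this diff:

    import mindspore.nn as nn
    from mindspore import context

    # Parameter server mode has to be on before the layer is built, otherwise
    # the new check in __init__ rejects the cache configuration.
    context.set_ps_context(enable_ps=True)

    # Hypothetical sizes: the cache occupies 'DEVICE' memory, so
    # vocab_cache_size should stay well below the full vocab_size.
    net = nn.EmbeddingLookup(vocab_size=200000,
                             embedding_size=64,
                             target='DEVICE',
                             sparse=True,
                             vocab_cache_size=10000)
    net.set_param_ps()  # place the embedding table on the parameter server
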
@@ -303,35 +303,13 @@
                 raise ValueError("For '{}', the 'slice_mode' must be in {}, "
                                  "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
         if self.cache_enable and not enable_ps:
-            if parallel_mode != ParallelMode.STAND_ALONE:
-                raise ValueError(f"For '{self.cls_name}', parallel mode haven't supported cache enable yet.")
-            self._set_cache_enable()
+            raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
         self.embedding_table.unique = self.forward_unique
         self.max_norm = max_norm
         if self.max_norm is not None:
             self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
             self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
 
-    def _set_cache_enable(self):
-        """EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
-        if self.target != 'DEVICE':
-            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
-                             f"when 'target' is 'DEVICE', but got 'target': {self.target}")
-        if not self.sparse:
-            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
-                             f"when 'sparse' is true, but got 'sparse': {self.sparse}.")
-        if context.get_context("device_target") != 'Ascend':
-            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
-                             f"when device target is 'Ascend', but got {context.get_context('device_target')}.")
-        logger.info("EmbeddingLookup cache enable takes effect.")
-        self.forward_unique = True
-        self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
-        self.unique.add_prim_attr('cache_enable', True)
-        self.embedding_table.cache_enable = self.cache_enable
-        self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
-        self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
-
     def _process_vocab_cache(self, slice_mode):
         """PS embeddingLookup cache check and process."""
         self.cache_enable = False
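
The second hunk turns the removed device-cache path into a hard error: building the layer with a cache but without parameter server training now fails fast instead of falling through to _set_cache_enable. A minimal sketch of that behaviour, assuming only the standard mindspore.nn import and a hypothetical vocabulary size (the error text comes from the raise added above):

    import mindspore.nn as nn

    try:
        # vocab_cache_size > 0 sets cache_enable, but no PS context is active,
        # so the new check in __init__ raises immediately.
        nn.EmbeddingLookup(vocab_size=200000,
                           embedding_size=64,
                           vocab_cache_size=10000)
    except ValueError as err:
        print(err)  # For 'EmbeddingLookup', haven't supported cache enable for not ps mode.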