forked from mindspore-Ecosystem/mindspore
!25853 remove cache for embeddinglookup layer
Merge pull request !25853 from fangzehua/remove_cache_embedding
This commit is contained in:
commit
ea5ff98bf7
|
@ -181,8 +181,8 @@ class EmbeddingLookup(Cell):
|
|||
or None. Default: None
|
||||
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
||||
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
|
||||
'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
|
||||
In addition, it should be noted that it will cost the 'DEVICE'
|
||||
parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
|
||||
optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
|
||||
memory, so suggests setting a reasonable value to avoid insufficient memory.
|
||||
|
||||
Inputs:
|
||||
|
@ -303,35 +303,13 @@ class EmbeddingLookup(Cell):
|
|||
raise ValueError("For '{}', the 'slice_mode' must be in {}, "
|
||||
"but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
|
||||
if self.cache_enable and not enable_ps:
|
||||
if parallel_mode != ParallelMode.STAND_ALONE:
|
||||
raise ValueError(f"For '{self.cls_name}', parallel mode haven't supported cache enable yet.")
|
||||
self._set_cache_enable()
|
||||
raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
|
||||
self.embedding_table.unique = self.forward_unique
|
||||
self.max_norm = max_norm
|
||||
if self.max_norm is not None:
|
||||
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
|
||||
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
|
||||
|
||||
def _set_cache_enable(self):
|
||||
"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
|
||||
if self.target != 'DEVICE':
|
||||
raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
|
||||
f"when 'target' is 'DEVICE', but got 'target': {self.target}")
|
||||
if not self.sparse:
|
||||
raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
|
||||
f"when 'sparse' is true, but got 'sparse': {self.sparse}.")
|
||||
if context.get_context("device_target") != 'Ascend':
|
||||
raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
|
||||
f"when device target is 'Ascend', but got {context.get_context('device_target')}.")
|
||||
|
||||
logger.info("EmbeddingLookup cache enable takes effect.")
|
||||
self.forward_unique = True
|
||||
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
|
||||
self.unique.add_prim_attr('cache_enable', True)
|
||||
self.embedding_table.cache_enable = self.cache_enable
|
||||
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
|
||||
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def _process_vocab_cache(self, slice_mode):
|
||||
"""PS embeddingLookup cache check and process."""
|
||||
self.cache_enable = False
|
||||
|
|
Loading…
Reference in New Issue