forked from mindspore-Ecosystem/mindspore
add check for embeding look up
This commit is contained in:
parent
9481bd6e23
commit
cd7ff5e60b
|
@ -5017,7 +5017,7 @@ class Sort(PrimitiveWithInfer):
|
|||
return x_dtype, mstype.tensor_type(mstype.int32)
|
||||
|
||||
|
||||
class EmbeddingLookup(PrimitiveWithInfer):
|
||||
class EmbeddingLookup(PrimitiveWithCheck):
|
||||
"""
|
||||
Returns a slice of input tensor based on the specified indices.
|
||||
|
||||
|
@ -5063,28 +5063,13 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
|
||||
outputs=['output'])
|
||||
|
||||
def __infer__(self, params, indices, offset):
|
||||
def __check__(self, params, indices, offset):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
||||
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
||||
params_shp = params['shape']
|
||||
if len(params_shp) > 2:
|
||||
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
||||
out_shape = indices['shape'] + params_shp[1:]
|
||||
if 'max_shape' in indices:
|
||||
out_max_shape = indices['max_shape'] + params_shp[1:]
|
||||
else:
|
||||
out_max_shape = out_shape
|
||||
if 'min_shape' in indices:
|
||||
out_min_shape = indices['min_shape'] + params_shp[1:]
|
||||
else:
|
||||
out_min_shape = out_shape
|
||||
out = {'shape': out_shape,
|
||||
'dtype': params['dtype'],
|
||||
'value': None,
|
||||
'max_shape': out_max_shape,
|
||||
'min_shape': out_min_shape}
|
||||
return out
|
||||
|
||||
|
||||
class GatherD(PrimitiveWithInfer):
|
||||
|
|
|
@ -243,7 +243,7 @@ def test():
|
|||
args = parse_args()
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True, device_id=devid)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False, device_id=devid)
|
||||
|
||||
# logger
|
||||
args.outputs_dir = os.path.join(args.log_path,
|
||||
|
|
|
@ -134,7 +134,7 @@ def conver_training_shape(args):
|
|||
def network_init(args):
|
||||
devid = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target, save_graphs=True, device_id=devid)
|
||||
device_target=args.device_target, save_graphs=False, device_id=devid)
|
||||
# init distributed
|
||||
if args.is_distributed:
|
||||
if args.device_target == "Ascend":
|
||||
|
|
Loading…
Reference in New Issue