diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b8bb8f0ade6..346a8a57f1f 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -5017,7 +5017,7 @@ class Sort(PrimitiveWithInfer): return x_dtype, mstype.tensor_type(mstype.int32) -class EmbeddingLookup(PrimitiveWithInfer): +class EmbeddingLookup(PrimitiveWithCheck): """ Returns a slice of input tensor based on the specified indices. @@ -5063,28 +5063,13 @@ class EmbeddingLookup(PrimitiveWithInfer): self.init_prim_io_names(inputs=['params', 'indices', 'offset'], outputs=['output']) - def __infer__(self, params, indices, offset): + def __check__(self, params, indices, offset): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) params_shp = params['shape'] if len(params_shp) > 2: raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp)) - out_shape = indices['shape'] + params_shp[1:] - if 'max_shape' in indices: - out_max_shape = indices['max_shape'] + params_shp[1:] - else: - out_max_shape = out_shape - if 'min_shape' in indices: - out_min_shape = indices['min_shape'] + params_shp[1:] - else: - out_min_shape = out_shape - out = {'shape': out_shape, - 'dtype': params['dtype'], - 'value': None, - 'max_shape': out_max_shape, - 'min_shape': out_min_shape} - return out class GatherD(PrimitiveWithInfer): diff --git a/model_zoo/official/cv/yolov3_darknet53/eval.py b/model_zoo/official/cv/yolov3_darknet53/eval.py index cdb119b3a3e..7e230e68cf0 100644 --- a/model_zoo/official/cv/yolov3_darknet53/eval.py +++ b/model_zoo/official/cv/yolov3_darknet53/eval.py @@ -243,7 +243,7 @@ def test(): args = parse_args() devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0 - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True, device_id=devid) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False, device_id=devid) # logger args.outputs_dir = os.path.join(args.log_path, diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py index eef9f8192e2..5e2eaf89101 100644 --- a/model_zoo/official/cv/yolov3_darknet53/train.py +++ b/model_zoo/official/cv/yolov3_darknet53/train.py @@ -134,7 +134,7 @@ def conver_training_shape(args): def network_init(args): devid = int(os.getenv('DEVICE_ID', '0')) context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target=args.device_target, save_graphs=True, device_id=devid) + device_target=args.device_target, save_graphs=False, device_id=devid) # init distributed if args.is_distributed: if args.device_target == "Ascend":