diff --git a/model_zoo/official/nlp/mass/src/transformer/infer_mass.py b/model_zoo/official/nlp/mass/src/transformer/infer_mass.py index 8e0c9296750..5b02f033e95 100644 --- a/model_zoo/official/nlp/mass/src/transformer/infer_mass.py +++ b/model_zoo/official/nlp/mass/src/transformer/infer_mass.py @@ -22,20 +22,11 @@ from mindspore.common.tensor import Tensor from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore import context - from src.dataset import load_dataset from .transformer_for_infer import TransformerInferModel from .transformer_for_train import TransformerTraining from ..utils.load_weights import load_infer_weights -context.set_context( - mode=context.GRAPH_MODE, - save_graphs=False, - device_target="Ascend", - reserve_class_name_in_scope=False) - - class TransformerInferCell(nn.Cell): """ Encapsulation class of transformer network infer.