update mass gpu network.

This commit is contained in:
linqingke 2020-10-13 14:46:00 +08:00
parent 81d9015ddb
commit 771042a457
1 changed files with 0 additions and 9 deletions

View File

@ -22,20 +22,11 @@ from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context
from src.dataset import load_dataset
from .transformer_for_infer import TransformerInferModel
from .transformer_for_train import TransformerTraining
from ..utils.load_weights import load_infer_weights
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend",
reserve_class_name_in_scope=False)
class TransformerInferCell(nn.Cell):
"""
Encapsulation class of transformer network infer.