forked from mindspore-Ecosystem/mindspore
!7234 update mass gpu support.
Merge pull request !7234 from linqingke/mass
This commit is contained in:
commit
2b56b87734
|
@ -22,20 +22,11 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from mindspore import context
|
||||
|
||||
from src.dataset import load_dataset
|
||||
from .transformer_for_infer import TransformerInferModel
|
||||
from .transformer_for_train import TransformerTraining
|
||||
from ..utils.load_weights import load_infer_weights
|
||||
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
reserve_class_name_in_scope=False)
|
||||
|
||||
|
||||
class TransformerInferCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of transformer network infer.
|
||||
|
|
Loading…
Reference in New Issue