forked from mindspore-Ecosystem/mindspore
!7234 update mass gpu support.
Merge pull request !7234 from linqingke/mass
This commit is contained in:
commit
2b56b87734
|
@ -22,20 +22,11 @@ from mindspore.common.tensor import Tensor
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
from mindspore import context
|
|
||||||
|
|
||||||
from src.dataset import load_dataset
|
from src.dataset import load_dataset
|
||||||
from .transformer_for_infer import TransformerInferModel
|
from .transformer_for_infer import TransformerInferModel
|
||||||
from .transformer_for_train import TransformerTraining
|
from .transformer_for_train import TransformerTraining
|
||||||
from ..utils.load_weights import load_infer_weights
|
from ..utils.load_weights import load_infer_weights
|
||||||
|
|
||||||
context.set_context(
|
|
||||||
mode=context.GRAPH_MODE,
|
|
||||||
save_graphs=False,
|
|
||||||
device_target="Ascend",
|
|
||||||
reserve_class_name_in_scope=False)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerInferCell(nn.Cell):
|
class TransformerInferCell(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Encapsulation class of transformer network infer.
|
Encapsulation class of transformer network infer.
|
||||||
|
|
Loading…
Reference in New Issue