forked from mindspore-Ecosystem/mindspore
fix deeptext mindir export fail
This commit is contained in:
parent
14cf33a6df
commit
69dacd458b
|
@ -16,6 +16,7 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16_Infer
|
||||
|
@ -40,6 +41,10 @@ if __name__ == '__main__':
|
|||
|
||||
load_param_into_net(net, param_dict_new)
|
||||
|
||||
device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
|
||||
if device_type == "Ascend":
|
||||
net.to_float(mstype.float16)
|
||||
|
||||
img_data = Tensor(np.zeros([config.test_batch_size, 3, config.img_height, config.img_width]), ms.float32)
|
||||
|
||||
export(net, img_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
|
Loading…
Reference in New Issue