modify export for centerface, fix yolov4 export bug

This commit is contained in:
yuzhenhua 2020-11-20 11:42:13 +08:00
parent 7689062c7d
commit 6336825d7a
2 changed files with 30 additions and 38 deletions

View File

@ -12,51 +12,43 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Convert ckpt to air."""
import os
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context import mindspore
from mindspore import Tensor from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.centerface import CenterfaceMobilev2 from src.centerface import CenterfaceMobilev2
from src.config import ConfigCenterface
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) parser = argparse.ArgumentParser(description='centerface export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="centerface.air", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args()
def save_air(): context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
"""Save air file"""
print('============= centerface start save air ==================')
parser = argparse.ArgumentParser(description='Convert ckpt to air') if __name__ == '__main__':
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') config = ConfigCenterface()
parser.add_argument('--batch_size', type=int, default=8, help='batch size') net = CenterfaceMobilev2()
args = parser.parse_args() param_dict = load_checkpoint(args.ckpt_file)
network = CenterfaceMobilev2() param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
continue
elif key.startswith('centerface_network.'):
param_dict_new[key[19:]] = values
else:
param_dict_new[key] = values
if os.path.isfile(args.pretrained): load_param_into_net(net, param_dict_new)
param_dict = load_checkpoint(args.pretrained) net.set_train(False)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
continue
elif key.startswith('centerface_network.'):
param_dict_new[key[19:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
print('load model {} success'.format(args.pretrained))
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 832, 832)).astype(np.float32) input_data = Tensor(np.zeros([args.batch_size, 3, config.input_h, config.input_w]), mindspore.float32)
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
tensor_input_data = Tensor(input_data)
export(network, tensor_input_data,
file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR')
print("export model success.")
if __name__ == "__main__":
save_air()

View File

@ -26,7 +26,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--testing_shape", type=int, default=608, help="test shape") parser.add_argument("--testing_shape", type=int, default=608, help="test shape")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="ssd.air", help="output file name.") parser.add_argument("--file_name", type=str, default="yolov4.air", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args() args = parser.parse_args()