forked from mindspore-Ecosystem/mindspore
!11541 modify transformer export
From: @changzherui Reviewed-by: @guoqi1024,@kingxian Signed-off-by: @guoqi1024
This commit is contained in:
commit
14a6713d08
|
@ -26,7 +26,6 @@ from eval import load_weights
|
|||
|
||||
parser = argparse.ArgumentParser(description='transformer export')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--file_name", type=str, default="transformer", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
|
@ -43,7 +42,7 @@ if __name__ == '__main__':
|
|||
parameter_dict = load_weights(cfg.model_file)
|
||||
load_param_into_net(tfm_model, parameter_dict)
|
||||
|
||||
source_ids = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))
|
||||
source_mask = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))
|
||||
source_ids = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))
|
||||
source_mask = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))
|
||||
|
||||
export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format)
|
||||
|
|
Loading…
Reference in New Issue