!10267 modify export file for gat,gcn,tinybert,wide_and_deep to support mindir and GPU

From: @yuzhenhua666
Reviewed-by: @linqingke,@c_34
Signed-off-by: @linqingke,@c_34
This commit is contained in:
mindspore-ci-bot 2020-12-21 21:06:53 +08:00 committed by Gitee
commit e5cacf7dec
5 changed files with 68 additions and 63 deletions

View File

@ -35,7 +35,7 @@ args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
if __name__ == "__main__": if __name__ == "__main__":
if args.platform != "GPU": if args.device_target != "GPU":
raise ValueError("Only supported GPU now.") raise ValueError("Only supported GPU now.")
net = efficientnet_b0(num_classes=cfg.num_classes, net = efficientnet_b0(num_classes=cfg.num_classes,

View File

@ -12,26 +12,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""export checkpoint file into air models""" """export checkpoint file into air, mindir and onnx models"""
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, export
from mindspore.train.serialization import load_checkpoint, export
from src.gat import GAT from src.gat import GAT
from src.config import GatConfig from src.config import GatConfig
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") parser = argparse.ArgumentParser(description="GAT export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--dataset", type=str, default="cora", choices=["cora", "citeseer"], help="Dataset.")
parser.add_argument("--file_name", type=str, default="gat", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
args = parser.parse_args()
if __name__ == '__main__': context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
parser = argparse.ArgumentParser(description='GAT_export')
parser.add_argument('--ckpt_file', type=str, default='./ckpts/gat.ckpt', help='GAT ckpt file.')
parser.add_argument('--output_file', type=str, default='gat.air', help='GAT output air name.')
parser.add_argument('--dataset', type=str, default='cora', help='GAT dataset name.')
args_opt = parser.parse_args()
if args_opt.dataset == "citeseer": if __name__ == "__main__":
if args.dataset == "citeseer":
feature_size = [1, 3312, 3703] feature_size = [1, 3312, 3703]
biases_size = [1, 3312, 3312] biases_size = [1, 3312, 3312]
num_classes = 6 num_classes = 6
@ -58,7 +63,7 @@ if __name__ == '__main__':
ftr_drop=0.0) ftr_drop=0.0)
gat_net.set_train(False) gat_net.set_train(False)
load_checkpoint(args_opt.ckpt_file, net=gat_net) load_checkpoint(args.ckpt_file, net=gat_net)
gat_net.add_flags_recursive(fp16=True) gat_net.add_flags_recursive(fp16=True)
export(gat_net, Tensor(feature), Tensor(biases), file_name=args_opt.output_file, file_format="AIR") export(gat_net, Tensor(feature), Tensor(biases), file_name=args.file_name, file_format=args.file_format)

View File

@ -16,24 +16,27 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, export
from mindspore.train.serialization import load_checkpoint, export
from src.gcn import GCN from src.gcn import GCN
from src.config import ConfigGCN from src.config import ConfigGCN
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") parser = argparse.ArgumentParser(description="GCN export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--dataset", type=str, default="cora", choices=["cora", "citeseer"], help="Dataset.")
parser.add_argument("--file_name", type=str, default="gcn", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
args = parser.parse_args()
if __name__ == '__main__': context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
parser = argparse.ArgumentParser(description='GCN_export')
parser.add_argument('--ckpt_file', type=str, default='', help='GCN ckpt file.')
parser.add_argument('--output_file', type=str, default='gcn.air', help='GCN output air name.')
parser.add_argument('--dataset', type=str, default='cora', help='GCN dataset name.')
args_opt = parser.parse_args()
if __name__ == "__main__":
config = ConfigGCN() config = ConfigGCN()
if args_opt.dataset == "cora": if args.dataset == "cora":
input_dim = 1433 input_dim = 1433
class_num = 7 class_num = 7
adj = Tensor(np.zeros((2708, 2708), np.float64)) adj = Tensor(np.zeros((2708, 2708), np.float64))
@ -47,7 +50,7 @@ if __name__ == '__main__':
gcn_net = GCN(config, input_dim, class_num) gcn_net = GCN(config, input_dim, class_num)
gcn_net.set_train(False) gcn_net.set_train(False)
load_checkpoint(args_opt.ckpt_file, net=gcn_net) load_checkpoint(args.ckpt_file, net=gcn_net)
gcn_net.add_flags_recursive(fp16=True) gcn_net.add_flags_recursive(fp16=True)
export(gcn_net, adj, feature, file_name=args_opt.output_file, file_format="AIR") export(gcn_net, adj, feature, file_name=args.file_name, file_format=args.file_format)

View File

@ -25,11 +25,17 @@ from src.td_config import td_student_net_cfg
from src.tinybert_model import BertModelCLS from src.tinybert_model import BertModelCLS
parser = argparse.ArgumentParser(description='tinybert task distill') parser = argparse.ArgumentParser(description='tinybert task distill')
parser.add_argument('--ckpt_file', type=str, required=True, help='tinybert ckpt file.') parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument('--output_file', type=str, default='tinybert', help='tinybert output air name.') parser.add_argument("--ckpt_file", type=str, required=True, help="tinybert ckpt file.")
parser.add_argument("--file_name", type=str, default="tinybert", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name') parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
DEFAULT_NUM_LABELS = 2 DEFAULT_NUM_LABELS = 2
DEFAULT_SEQ_LENGTH = 128 DEFAULT_SEQ_LENGTH = 128
DEFAULT_BS = 32 DEFAULT_BS = 32
@ -37,8 +43,6 @@ task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
"QNLI": {"num_labels": 2, "seq_length": 128}, "QNLI": {"num_labels": 2, "seq_length": 128},
"MNLI": {"num_labels": 3, "seq_length": 128}} "MNLI": {"num_labels": 3, "seq_length": 128}}
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Task: class Task:
""" """
Encapsulation class of get the task parameter. Encapsulation class of get the task parameter.
@ -78,4 +82,5 @@ if __name__ == '__main__':
token_type_id = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32)) token_type_id = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32))
input_mask = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32)) input_mask = Tensor(np.zeros((td_student_net_cfg.batch_size, task.seq_length), np.int32))
export(eval_model, input_ids, token_type_id, input_mask, file_name=args.output_file, file_format="AIR") input_data = [input_ids, token_type_id, input_mask]
export(eval_model, *input_data, file_name=args.file_name, file_format=args.file_format)

View File

@ -13,48 +13,40 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
##############export checkpoint file into air and onnx models################# ##############export checkpoint file into air, mindir and onnx models#################
""" """
import argparse
import numpy as np import numpy as np
from mindspore import Tensor, nn from mindspore import Tensor, context, load_checkpoint, export, load_param_into_net
from mindspore.ops import operations as P
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.wide_and_deep import WideDeepModel from eval import ModelBuilder
from src.config import WideDeepConfig from src.config import WideDeepConfig
class PredictWithSigmoid(nn.Cell): parser = argparse.ArgumentParser(description="wide_and_deep export")
""" parser.add_argument("--device_id", type=int, default=0, help="Device id")
PredictWithSigmoid parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
""" parser.add_argument("--file_name", type=str, default="wide_and_deep", help="output file name.")
def __init__(self, network): parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
super(PredictWithSigmoid, self).__init__() parser.add_argument("--device_target", type=str, default="Ascend",
self.network = network choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
self.sigmoid = P.Sigmoid() args = parser.parse_args()
def construct(self, batch_ids, batch_wts): context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
logits, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return pred_probs
def get_WideDeep_net(config):
"""
Get network of wide&deep predict model.
"""
WideDeep_net = WideDeepModel(config)
eval_net = PredictWithSigmoid(WideDeep_net)
return eval_net
if __name__ == '__main__': if __name__ == '__main__':
widedeep_config = WideDeepConfig() widedeep_config = WideDeepConfig()
widedeep_config.argparse_init() widedeep_config.argparse_init()
ckpt_path = widedeep_config.ckpt_path
net = get_WideDeep_net(widedeep_config) net_builder = ModelBuilder()
param_dict = load_checkpoint(ckpt_path) _, eval_net = net_builder.get_net(widedeep_config)
load_param_into_net(net, param_dict)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(eval_net, param_dict)
eval_net.set_train(False)
ids = Tensor(np.ones([widedeep_config.eval_batch_size, widedeep_config.field_size]).astype(np.int32)) ids = Tensor(np.ones([widedeep_config.eval_batch_size, widedeep_config.field_size]).astype(np.int32))
wts = Tensor(np.ones([widedeep_config.eval_batch_size, widedeep_config.field_size]).astype(np.float32)) wts = Tensor(np.ones([widedeep_config.eval_batch_size, widedeep_config.field_size]).astype(np.float32))
input_tensor_list = [ids, wts] label = Tensor(np.ones([widedeep_config.eval_batch_size, 1]).astype(np.float32))
export(net, *input_tensor_list, file_name='wide_and_deep', file_format="ONNX") input_tensor_list = [ids, wts, label]
export(net, *input_tensor_list, file_name='wide_and_deep', file_format="AIR") export(eval_net, *input_tensor_list, file_name=args.file_name, file_format=args.file_format)