Change export interface
This commit is contained in:
parent
6c4b4f91d2
commit
73325e0f01
|
@ -396,13 +396,13 @@ void ExecutorPy::GetGeBackendPolicy() const {
|
|||
}
|
||||
}
|
||||
|
||||
bool IsPhaseExportGeir(const std::string &phase_s) {
|
||||
auto phase_to_export = "export.geir";
|
||||
bool IsPhaseExportAir(const std::string &phase_s) {
|
||||
auto phase_to_export = "export.air";
|
||||
return phase_s.rfind(phase_to_export) != std::string::npos;
|
||||
}
|
||||
|
||||
std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
|
||||
bool is_geir = IsPhaseExportGeir(phase_s);
|
||||
bool is_air = IsPhaseExportAir(phase_s);
|
||||
|
||||
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||
|
||||
|
@ -419,7 +419,7 @@ std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::strin
|
|||
}
|
||||
#endif
|
||||
|
||||
if (use_vm && backend != "ge" && !is_geir) {
|
||||
if (use_vm && backend != "ge" && !is_air) {
|
||||
// Create backend and session
|
||||
auto backend_ptr = compile::CreateBackend();
|
||||
// Connect session to debugger
|
||||
|
@ -938,8 +938,9 @@ void FinalizeHccl() {
|
|||
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) {
|
||||
#if (ENABLE_GE || ENABLE_D)
|
||||
ExportDFGraph(file_name, phase);
|
||||
#else
|
||||
MS_EXCEPTION(ValueError) << "Only MindSpore with Ascend backend support exporting file in 'AIR' format.";
|
||||
#endif
|
||||
MS_LOG(WARNING) << "In ut test no export_graph";
|
||||
}
|
||||
|
||||
void ReleaseGeTsd() {
|
||||
|
|
|
@ -515,7 +515,7 @@ class _Executor:
|
|||
graph_id (str): id of graph to be exported
|
||||
"""
|
||||
from .._c_expression import export_graph
|
||||
export_graph(file_name, 'GEIR', graph_id)
|
||||
export_graph(file_name, 'AIR', graph_id)
|
||||
|
||||
def fetch_info_for_quant_export(self, exec_id):
|
||||
"""Get graph proto from pipeline."""
|
||||
|
|
|
@ -435,9 +435,9 @@ class ExportToQuantInferNetwork:
|
|||
return network
|
||||
|
||||
|
||||
def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='GEIR'):
|
||||
def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='AIR'):
|
||||
"""
|
||||
Exports MindSpore quantization predict model to deploy with GEIR.
|
||||
Exports MindSpore quantization predict model to deploy with AIR.
|
||||
|
||||
Args:
|
||||
network (Cell): MindSpore network produced by `convert_quant_network`.
|
||||
|
@ -445,17 +445,17 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
|
|||
file_name (str): File name of model to export.
|
||||
mean (int): Input data mean. Default: 127.5.
|
||||
std_dev (int, float): Input data variance. Default: 127.5.
|
||||
file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'MINDIR' format for exported
|
||||
quantization aware model. Default: 'GEIR'.
|
||||
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported
|
||||
quantization aware model. Default: 'AIR'.
|
||||
|
||||
- GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
||||
- AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
||||
Ascend model.
|
||||
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
||||
for MindSpore models.
|
||||
Recommended suffix for output file is '.mindir'.
|
||||
"""
|
||||
supported_device = ["Ascend", "GPU"]
|
||||
supported_formats = ['GEIR', 'MINDIR']
|
||||
supported_formats = ['AIR', 'MINDIR']
|
||||
|
||||
mean = validator.check_type("mean", mean, (int, float))
|
||||
std_dev = validator.check_type("std_dev", std_dev, (int, float))
|
||||
|
|
|
@ -445,7 +445,7 @@ def _fill_param_into_net(net, parameter_list):
|
|||
load_param_into_net(net, parameter_dict)
|
||||
|
||||
|
||||
def export(net, *inputs, file_name, file_format='GEIR'):
|
||||
def export(net, *inputs, file_name, file_format='AIR'):
|
||||
"""
|
||||
Exports MindSpore predict model to file in specified format.
|
||||
|
||||
|
@ -453,11 +453,12 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
|||
net (Cell): MindSpore network.
|
||||
inputs (Tensor): Inputs of the `net`.
|
||||
file_name (str): File name of model to export.
|
||||
file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'MINDIR' format for exported model.
|
||||
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
|
||||
|
||||
- GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
||||
Ascend model.
|
||||
- AIR: Ascend Intermidiate Representation. An intermidiate representation format of Ascend model.
|
||||
Recommended suffix for output file is '.air'.
|
||||
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
|
||||
Recommended suffix for output file is '.onnx'.
|
||||
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
||||
for MindSpore models.
|
||||
Recommended suffix for output file is '.mindir'.
|
||||
|
@ -465,7 +466,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
|||
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
||||
check_input_data(*inputs, data_class=Tensor)
|
||||
|
||||
supported_formats = ['GEIR', 'ONNX', 'MINDIR']
|
||||
if file_format == 'GEIR':
|
||||
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
|
||||
file_format = 'AIR'
|
||||
|
||||
supported_formats = ['AIR', 'ONNX', 'MINDIR']
|
||||
if file_format not in supported_formats:
|
||||
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
|
||||
# switch network mode to infer when it is training
|
||||
|
@ -474,13 +479,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
|||
net.set_train(mode=False)
|
||||
# export model
|
||||
net.init_parameters_data()
|
||||
if file_format == 'GEIR':
|
||||
phase_name = 'export.geir'
|
||||
if file_format == 'AIR':
|
||||
phase_name = 'export.air'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
||||
_executor.export(file_name, graph_id)
|
||||
elif file_format == 'ONNX': # file_format is 'ONNX'
|
||||
# NOTICE: the pahse name `export_onnx` is used for judging whether is exporting onnx in the compile pipeline,
|
||||
# do not change it to other values.
|
||||
phase_name = 'export.onnx'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||
onnx_stream = _executor._get_func_graph_proto(graph_id)
|
||||
|
|
|
@ -108,7 +108,7 @@ python eval.py > eval.log 2>&1 & OR sh run_eval.sh
|
|||
│ ├──config.py // parameter configuration
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
├── export.py // export checkpoint files into geir/onnx
|
||||
├── export.py // export checkpoint files into air/onnx
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
@ -133,7 +133,7 @@ Major parameters in train.py and config.py are:
|
|||
--checkpoint_path: The absolute full path to the checkpoint file saved
|
||||
after training.
|
||||
--onnx_filename: File name of the onnx model used in export.py.
|
||||
--geir_filename: File name of the geir model used in export.py.
|
||||
--air_filename: File name of the air model used in export.py.
|
||||
```
|
||||
|
||||
|
||||
|
@ -226,7 +226,7 @@ accuracy: {'acc': 0.9217}
|
|||
| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins |
|
||||
| Parameters (M) | 13.0 |
|
||||
| Checkpoint for Fine tuning | 43.07M (.ckpt file) |
|
||||
| Model for inference | 21.50M (.onnx file), 21.60M(.geir file) |
|
||||
| Model for inference | 21.50M (.onnx file), 21.60M(.air file) |
|
||||
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet |
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into geir and onnx models#################
|
||||
##############export checkpoint file into air and onnx models#################
|
||||
python export.py
|
||||
"""
|
||||
import numpy as np
|
||||
|
@ -33,4 +33,4 @@ if __name__ == '__main__':
|
|||
|
||||
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32)
|
||||
export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX")
|
||||
export(net, input_arr, file_name=cfg.geir_filename, file_format="GEIR")
|
||||
export(net, input_arr, file_name=cfg.air_filename, file_format="AIR")
|
||||
|
|
|
@ -34,5 +34,5 @@ cifar_cfg = edict({
|
|||
'keep_checkpoint_max': 10,
|
||||
'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt',
|
||||
'onnx_filename': 'googlenet.onnx',
|
||||
'geir_filename': 'googlenet.geir'
|
||||
'air_filename': 'googlenet.air'
|
||||
})
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into geir and onnx models#################
|
||||
##############export checkpoint file into air and onnx models#################
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
@ -37,4 +37,4 @@ if __name__ == '__main__':
|
|||
|
||||
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 299, 299]), ms.float32)
|
||||
export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX")
|
||||
export(net, input_arr, file_name=cfg.geir_filename, file_format="GEIR")
|
||||
export(net, input_arr, file_name=cfg.air_filename, file_format="AIR")
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
export quantization aware training network to infer `GEIR` backend.
|
||||
export quantization aware training network to infer `AIR` backend.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@ -53,4 +53,4 @@ if __name__ == "__main__":
|
|||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR')
|
||||
quant.export(network, inputs, file_name="lenet_quant", file_format='AIR')
|
||||
|
|
|
@ -50,5 +50,5 @@ if __name__ == '__main__':
|
|||
# export network
|
||||
print("============== Starting export ==============")
|
||||
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR')
|
||||
quant.export(network, inputs, file_name="mobilenet_quant", file_format='AIR')
|
||||
print("============== End export ==============")
|
||||
|
|
|
@ -24,4 +24,4 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
|||
def test_resnet50_export(batch_size=1, num_classes=5):
|
||||
input_np = np.random.uniform(0.0, 1.0, size=[batch_size, 3, 224, 224]).astype(np.float32)
|
||||
net = resnet50(batch_size, num_classes)
|
||||
export(net, Tensor(input_np), file_name="./me_resnet50.pb", file_format="GEIR")
|
||||
export(net, Tensor(input_np), file_name="./me_resnet50.pb", file_format="AIR")
|
||||
|
|
|
@ -87,8 +87,12 @@ def test_save_graph():
|
|||
x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
out_put = net(x, y)
|
||||
_save_graph(network=net, file_name="net-graph.meta")
|
||||
output_file = "net-graph.meta"
|
||||
_save_graph(network=net, file_name=output_file)
|
||||
out_me_list.append(out_put)
|
||||
assert os.path.exists(output_file)
|
||||
os.chmod(output_file, stat.S_IWRITE)
|
||||
os.remove(output_file)
|
||||
|
||||
|
||||
def test_save_checkpoint():
|
||||
|
@ -318,7 +322,8 @@ class MYNET(nn.Cell):
|
|||
def test_export():
|
||||
net = MYNET()
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
|
||||
export(net, input_data, file_name="./me_export.pb", file_format="GEIR")
|
||||
with pytest.raises(ValueError):
|
||||
export(net, input_data, file_name="./me_export.pb", file_format="AIR")
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
|
|
Loading…
Reference in New Issue