remove two context param

This commit is contained in:
jinyaohui 2020-05-12 22:11:50 +08:00
parent 298a784878
commit 391a060f21
37 changed files with 66 additions and 130 deletions

View File

@ -58,7 +58,6 @@ options:
--epoch_size epoch size: N, default is 1 --epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1 --device_num number of used devices: N, default is 1
--device_id device id: N, default is 0 --device_id device id: N, default is 0
--enable_task_sink enable task sink: "true" | "false", default is "true"
--enable_loop_sink enable loop sink: "true" | "false", default is "true" --enable_loop_sink enable loop sink: "true" | "false", default is "true"
--enable_mem_reuse enable memory reuse: "true" | "false", default is "true" --enable_mem_reuse enable memory reuse: "true" | "false", default is "true"
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true" --enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"

View File

@ -50,7 +50,6 @@ do
--epoch_size=$EPOCH_SIZE \ --epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \ --device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \ --device_num=$RANK_SIZE \
--enable_task_sink="true" \
--enable_loop_sink="true" \ --enable_loop_sink="true" \
--enable_mem_reuse="true" \ --enable_mem_reuse="true" \
--enable_save_ckpt="true" \ --enable_save_ckpt="true" \

View File

@ -59,7 +59,6 @@ def run_pretrain():
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.") parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_task_sink", type=str, default="true", help="Enable task sink, default is true.")
parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.") parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.") parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
@ -76,8 +75,7 @@ def run_pretrain():
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"), context.set_context(enable_loop_sink=(args_opt.enable_loop_sink == "true"),
enable_loop_sink=(args_opt.enable_loop_sink == "true"),
enable_mem_reuse=(args_opt.enable_mem_reuse == "true")) enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
context.set_context(reserve_class_name_in_scope=False) context.set_context(reserve_class_name_in_scope=False)

View File

@ -29,7 +29,6 @@ python run_pretrain.py \
--distribute="false" \ --distribute="false" \
--epoch_size=$EPOCH_SIZE \ --epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \ --device_id=$DEVICE_ID \
--enable_task_sink="true" \
--enable_loop_sink="true" \ --enable_loop_sink="true" \
--enable_mem_reuse="true" \ --enable_mem_reuse="true" \
--enable_save_ckpt="true" \ --enable_save_ckpt="true" \

View File

@ -70,7 +70,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id) context.set_context(device_id=args_opt.device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -34,7 +34,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -54,7 +54,6 @@ rank_size = int(os.getenv('RANK_SIZE'))
run_distribute = rank_size > 1 run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -46,7 +46,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -49,7 +49,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -39,7 +39,7 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -42,7 +42,7 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -71,7 +71,7 @@ if __name__ == '__main__':
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True) context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
config = ConfigSSD() config = ConfigSSD()
prefix = "ssd_eval.mindrecord" prefix = "ssd_eval.mindrecord"

View File

@ -93,7 +93,7 @@ def main():
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True) context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute: if args_opt.distribute:
device_num = args_opt.device_num device_num = args_opt.device_num

View File

@ -64,7 +64,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id) context.set_context(device_id=args_opt.device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -82,7 +82,7 @@ if __name__ == '__main__':
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True) context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
# It will generate mindrecord file in args_opt.mindrecord_dir, # It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is yolo.mindrecord0, 1, ... file_num. # and the file name is yolo.mindrecord0, 1, ... file_num.

View File

@ -85,7 +85,7 @@ def main():
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True) context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute: if args_opt.distribute:
device_num = args_opt.device_num device_num = args_opt.device_num
context.reset_auto_parallel_context() context.reset_auto_parallel_context()

View File

@ -115,12 +115,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
.def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.") .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.")
.def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.") .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
.def("set_task_sink_flag", &mindspore::MsContext::set_enable_task_sink, "Set enable task sink.")
.def("get_task_sink_flag", &mindspore::MsContext::enable_task_sink, "Get whether to enable task sink.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_ir_fusion_flag", &mindspore::MsContext::ir_fusion_flag, "Get whether to enable ir fusion.")
.def("set_ir_fusion_flag", &mindspore::MsContext::set_ir_fusion_flag, "Set whether to enable ir fusion.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
"Get whether to enable auto mixed precision.") "Get whether to enable auto mixed precision.")
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag, .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,

View File

@ -62,7 +62,6 @@ class MsContext {
bool enable_pynative_infer() const { return enable_pynative_infer_; } bool enable_pynative_infer() const { return enable_pynative_infer_; }
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; } void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
void set_enable_task_sink(bool enable_task_sink) { enable_task_sink_ = enable_task_sink; }
bool enable_task_sink() const { return enable_task_sink_; } bool enable_task_sink() const { return enable_task_sink_; }
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; } void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
@ -90,7 +89,6 @@ class MsContext {
bool enable_hccl() const { return enable_hccl_; } bool enable_hccl() const { return enable_hccl_; }
bool PynativeInitGe(); bool PynativeInitGe();
void set_ir_fusion_flag(bool ir_fusion_flag) { ir_fusion_flag_ = ir_fusion_flag; }
bool ir_fusion_flag() const { return ir_fusion_flag_; } bool ir_fusion_flag() const { return ir_fusion_flag_; }
void set_loop_sink_flag(bool loop_sink_flag) { enable_loop_sink_ = loop_sink_flag; } void set_loop_sink_flag(bool loop_sink_flag) { enable_loop_sink_ = loop_sink_flag; }

View File

@ -142,15 +142,6 @@ class _Context:
raise ValueError("Context handle is none in context!!!") raise ValueError("Context handle is none in context!!!")
return value return value
# For Ascend task sink mode execution
@property
def enable_task_sink(self):
return self._context_handle.get_task_sink_flag()
@enable_task_sink.setter
def enable_task_sink(self, task_sink):
self._context_handle.set_task_sink_flag(task_sink)
@property @property
def mode(self): def mode(self):
return self._context_handle.get_execution_mode() return self._context_handle.get_execution_mode()
@ -224,14 +215,6 @@ class _Context:
if not success: if not success:
raise RuntimeError("Device id set failed!!!") raise RuntimeError("Device id set failed!!!")
@property
def enable_ir_fusion(self):
return self._context_handle.get_ir_fusion_flag()
@enable_ir_fusion.setter
def enable_ir_fusion(self, enable_ir_fusion):
self._context_handle.set_ir_fusion_flag(enable_ir_fusion)
@property @property
def enable_loop_sink(self): def enable_loop_sink(self):
return self._context_handle.get_loop_sink_flag() return self._context_handle.get_loop_sink_flag()
@ -485,11 +468,9 @@ def reset_auto_parallel_context():
_reset_auto_parallel_context() _reset_auto_parallel_context()
@args_type_check(mode=int, precompile_only=bool, device_target=str, @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
device_id=int, enable_ir_fusion=bool, save_graphs=bool, save_graphs_path=str, enable_loop_sink=bool, enable_mem_reuse=bool, save_ms_model=bool,
enable_task_sink=bool, save_graphs_path=str, enable_loop_sink=bool, save_ms_model_path=str, enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str,
enable_mem_reuse=bool, save_ms_model=bool, save_ms_model_path=str,
enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str,
enable_reduce_precision=bool, graph_memory_max_size=str, enable_reduce_precision=bool, graph_memory_max_size=str,
variable_memory_max_size=str, enable_profiling=bool, profiling_options=str) variable_memory_max_size=str, enable_profiling=bool, profiling_options=str)
def set_context(**kwargs): def set_context(**kwargs):
@ -517,10 +498,8 @@ def set_context(**kwargs):
device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend". device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
device_id (int): Id of target device, the value must be in [0, device_num_per_host-1], device_id (int): Id of target device, the value must be in [0, device_num_per_host-1],
while device_num_per_host should no more than 4096. Default: 0. while device_num_per_host should no more than 4096. Default: 0.
enable_ir_fusion (bool): Whether to enable ir fusion. Default: True.
save_graphs (bool): Whether to save graphs. Default: False. save_graphs (bool): Whether to save graphs. Default: False.
enable_loop_sink (bool): Whether to enable loop sink. Default: True. enable_loop_sink (bool): Whether to enable loop sink. Default: True.
enable_task_sink (bool): Whether to enable task sink. Default: True.
enable_mem_reuse (bool): Whether to enable memory reuse. Default: True. enable_mem_reuse (bool): Whether to enable memory reuse. Default: True.
save_ms_model (bool): Whether to save lite model converted by graph. Default: False. save_ms_model (bool): Whether to save lite model converted by graph. Default: False.
save_ms_model_path (str): Path to save converted lite model. Default: "." save_ms_model_path (str): Path to save converted lite model. Default: "."
@ -559,7 +538,6 @@ def set_context(**kwargs):
>>> context.set_context(device_target="Ascend") >>> context.set_context(device_target="Ascend")
>>> context.set_context(device_id=0) >>> context.set_context(device_id=0)
>>> context.set_context(save_graphs=True, save_graphs_path="./model.ms") >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
>>> context.set_context(enable_task_sink=True)
>>> context.set_context(enable_mem_reuse=True) >>> context.set_context(enable_mem_reuse=True)
>>> context.set_context(enable_reduce_precision=True) >>> context.set_context(enable_reduce_precision=True)
>>> context.set_context(save_ms_model=True, save_ms_model_path=".") >>> context.set_context(save_ms_model=True, save_ms_model_path=".")

View File

@ -33,9 +33,7 @@ def setup_module():
global rank_id global rank_id
np.random.seed(0) np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, context.set_context(device_id=device_id)
device_id=device_id)
context.set_context(enable_ir_fusion=True)
context.set_context(enable_loop_sink=False) context.set_context(enable_loop_sink=False)
distributedTool.init() distributedTool.init()
device_num = distributedTool.get_group_size() device_num = distributedTool.get_group_size()
@ -86,15 +84,15 @@ class DataGenerator():
return data return data
def input_data(self, shape): def input_data(self, shape):
data = (self.generate_data(shape)*2).astype(np.float32) data = (self.generate_data(shape) * 2).astype(np.float32)
stra = [1]*len(shape) stra = [1] * len(shape)
stra[0] = device_num stra[0] = device_num
datas = self.get_parallel_blocks(data, stra) datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id]) return Tensor(data), Tensor(datas[rank_id])
def label_data(self, shape, classes): def label_data(self, shape, classes):
data = (self.generate_data(shape)*(classes-1)).astype(np.int32) data = (self.generate_data(shape) * (classes - 1)).astype(np.int32)
stra = [1]*len(shape) stra = [1] * len(shape)
stra[0] = device_num stra[0] = device_num
datas = self.get_parallel_blocks(data, stra) datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id]) return Tensor(data), Tensor(datas[rank_id])

View File

@ -37,7 +37,7 @@ device_id = int(os.getenv('DEVICE_ID'))
rank_id = 0 rank_id = 0
embed = 128 embed = 128
classes = 32 classes = 32
batch_size = 32*2 batch_size = 32 * 2
MatmulParamShape = (classes, embed) MatmulParamShape = (classes, embed)
@ -46,9 +46,7 @@ def setup_module():
global rank_id global rank_id
np.random.seed(0) np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, context.set_context(device_id=device_id)
device_id=device_id)
context.set_context(enable_ir_fusion=True)
context.set_context(enable_loop_sink=False) context.set_context(enable_loop_sink=False)
distributedTool.init() distributedTool.init()
rank_id = distributedTool.get_rank() rank_id = distributedTool.get_rank()
@ -77,20 +75,20 @@ class DataGenerator():
def generate_data(self, shape): def generate_data(self, shape):
size = np.cumprod(shape)[-1] size = np.cumprod(shape)[-1]
num_range = min(size, 1000) num_range = min(size, 1000)
data = (np.arange(0, size) % num_range)/num_range data = (np.arange(0, size) % num_range) / num_range
data = np.reshape(data, shape) data = np.reshape(data, shape)
return data return data
def input_data(self, shape): def input_data(self, shape):
data = (self.generate_data(shape)*0.1).astype(np.float32) data = (self.generate_data(shape) * 0.1).astype(np.float32)
stra = [1]*len(shape) stra = [1] * len(shape)
stra[0] = device_num stra[0] = device_num
datas = self.get_parallel_blocks(data, stra) datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id]) return Tensor(data), Tensor(datas[rank_id])
def label_data(self, shape, embed): def label_data(self, shape, embed):
data = (self.generate_data(shape)*(embed-1)).astype(np.int32) data = (self.generate_data(shape) * (embed - 1)).astype(np.int32)
stra = [1]*len(shape) stra = [1] * len(shape)
stra[0] = device_num stra[0] = device_num
datas = self.get_parallel_blocks(data, stra) datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id]) return Tensor(data), Tensor(datas[rank_id])
@ -141,7 +139,7 @@ class SoftmaxCrossEntropyExpand(Cell):
def __init__(self, sparse=False, stra_list=[]): def __init__(self, sparse=False, stra_list=[]):
super(SoftmaxCrossEntropyExpand, self).__init__() super(SoftmaxCrossEntropyExpand, self).__init__()
if len(stra_list) < 11: if len(stra_list) < 11:
stra_list = [None]*11 stra_list = [None] * 11
self.exp = P.Exp() self.exp = P.Exp()
self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy=stra_list[1]) self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy=stra_list[1])
self.onehot = P.OneHot().set_strategy(strategy=stra_list[2]) self.onehot = P.OneHot().set_strategy(strategy=stra_list[2])

View File

@ -31,8 +31,7 @@ from mindspore.train.callback import Callback
from mindspore.parallel import set_algo_parameters from mindspore.parallel import set_algo_parameters
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=int(os.getenv('DEVICE_ID'))) context.set_context(device_id=int(os.getenv('DEVICE_ID')))
context.set_context(enable_ir_fusion=True)
context.set_context(enable_loop_sink=False) context.set_context(enable_loop_sink=False)
init() init()
context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL) context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
@ -316,14 +315,14 @@ class DataGenerator():
def input_data(self, shape): def input_data(self, shape):
data = (self.generate_data(shape)).astype(np.float32) data = (self.generate_data(shape)).astype(np.float32)
stra = [1]*len(shape) stra = [1] * len(shape)
stra[0] = device_num stra[0] = device_num
datas = self.get_parallel_blocks(data, stra) datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id]) return Tensor(data), Tensor(datas[rank_id])
def label_data(self, shape): def label_data(self, shape):
data = (self.generate_data(shape)*1000/np.prod(shape)).astype(np.int32) data = (self.generate_data(shape) * 1000 / np.prod(shape)).astype(np.int32)
stra = [1]*len(shape) stra = [1] * len(shape)
stra[0] = device_num stra[0] = device_num
datas = self.get_parallel_blocks(data, stra) datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id]) return Tensor(data), Tensor(datas[rank_id])
@ -378,8 +377,8 @@ def test_train_feed(num_classes=8192):
set_algo_parameters(elementwise_op_strategy_follow=True) set_algo_parameters(elementwise_op_strategy_follow=True)
parallel_callback = ModelCallback() parallel_callback = ModelCallback()
dataGen = DataGenerator() dataGen = DataGenerator()
input_full, input_part = dataGen.input_data((32*2, 3, 224, 224)) input_full, input_part = dataGen.input_data((32 * 2, 3, 224, 224))
label_full, label_part = dataGen.label_data((32*2,)) label_full, label_part = dataGen.label_data((32 * 2,))
dataset = Dataset(input_part, label_part) dataset = Dataset(input_part, label_part)
net = resnet50(num_classes) net = resnet50(num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True) loss = SoftmaxCrossEntropyExpand(sparse=True)
@ -398,8 +397,8 @@ def test_train_feed2(num_classes=1001):
set_algo_parameters(elementwise_op_strategy_follow=True) set_algo_parameters(elementwise_op_strategy_follow=True)
parallel_callback = ModelCallback() parallel_callback = ModelCallback()
dataGen = DataGenerator() dataGen = DataGenerator()
input_full, input_part = dataGen.input_data((32*2, 3, 224, 224)) input_full, input_part = dataGen.input_data((32 * 2, 3, 224, 224))
label_full, label_part = dataGen.label_data((32*2,)) label_full, label_part = dataGen.label_data((32 * 2,))
dataset = Dataset(input_part, label_part) dataset = Dataset(input_part, label_part)
net = resnet50(num_classes) net = resnet50(num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True) loss = SoftmaxCrossEntropyExpand(sparse=True)

View File

@ -14,17 +14,14 @@
# ============================================================================ # ============================================================================
""" test_multigraph_sink """ """ test_multigraph_sink """
import pytest import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common import ms_function from mindspore.common import ms_function
from mindspore.ops import operations as P
def setup_module(module): def setup_module(module):
context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend") context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
c1 = Tensor([2], mstype.int32) c1 = Tensor([2], mstype.int32)
@ -208,4 +205,3 @@ def test_while_in_while_in_while():
output = while_in_while_in_while(c1, c2, c3) output = while_in_while_in_while(c1, c2, c3)
expect = Tensor([2534], mstype.int32) expect = Tensor([2534], mstype.int32)
assert output == expect assert output == expect

View File

@ -31,7 +31,6 @@ def t1_while(x, y, z):
def test_net(): def test_net():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True)
c1 = Tensor([2], mstype.int32) c1 = Tensor([2], mstype.int32)
c2 = Tensor([14], mstype.int32) c2 = Tensor([14], mstype.int32)
c3 = Tensor([1], mstype.int32) c3 = Tensor([1], mstype.int32)

View File

@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
context.set_context(device_target="Ascend", enable_task_sink=True) context.set_context(device_target="Ascend")
input_channel = 2048 input_channel = 2048
output_channel = 512 output_channel = 512

View File

@ -53,7 +53,7 @@ device_id = int(os.getenv('DEVICE_ID'))
data_home = args_opt.dataset_path data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -36,7 +36,6 @@ random.seed(1)
np.random.seed(1) np.random.seed(1)
de.config.set_seed(1) de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.') parser.add_argument('--device_num', type=int, default=1, help='Device num.')
@ -54,7 +53,7 @@ device_id = int(os.getenv('DEVICE_ID'))
data_home = args_opt.dataset_path data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=False) context.set_context(enable_mem_reuse=False)

View File

@ -127,7 +127,6 @@ class ModelCallback(Callback):
def test_bert_tdt(): def test_bert_tdt():
"""test bert tdt""" """test bert tdt"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)
ds = me_de_train_dataset() ds = me_de_train_dataset()

View File

@ -15,14 +15,10 @@
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.api import ms_function
import numpy as np import numpy as np
import mindspore.context as context import mindspore.context as context
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True)
class Net(nn.Cell): class Net(nn.Cell):

View File

@ -21,7 +21,6 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
import mindspore.context as context import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", impl_type="tbe") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", impl_type="tbe")
context.set_context(enable_task_sink=True)
class Adam: class Adam:

View File

@ -15,27 +15,17 @@
import os import os
import numpy as np import numpy as np
from resnet_torch import resnet50 from resnet_torch import resnet50
from mindspore.train.callback import Callback
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import Tensor from mindspore import Tensor
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import save, load, save_checkpoint, load_checkpoint,\ from mindspore.train.serialization import save, load, _check_filedir_or_create, _chg_model_file_name_if_same_exist, \
load_param_into_net, _exec_save_checkpoint,\
_check_filedir_or_create, _chg_model_file_name_if_same_exist, \
_read_file_last_line, context, export _read_file_last_line, context, export
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_loop_sink=True)
enable_task_sink=True, enable_loop_sink=True, enable_ir_fusion=True)
def test_resnet50_export(batch_size=1, num_classes=5): def test_resnet50_export(batch_size=1, num_classes=5):
context.set_context(enable_ir_fusion=False)
input_np = np.random.uniform(0.0, 1.0, size=[batch_size, 3, 224, 224]).astype(np.float32) input_np = np.random.uniform(0.0, 1.0, size=[batch_size, 3, 224, 224]).astype(np.float32)
net = resnet50(batch_size, num_classes) net = resnet50(batch_size, num_classes)
#param_dict = load_checkpoint("./resnet50-1_103.ckpt") # param_dict = load_checkpoint("./resnet50-1_103.ckpt")
#load_param_into_net(net, param_dict) # load_param_into_net(net, param_dict)
export(net, Tensor(input_np), file_name="./me_resnet50.pb", file_format="GEIR") export(net, Tensor(input_np), file_name="./me_resnet50.pb", file_format="GEIR")

View File

@ -32,6 +32,7 @@ from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from resnet import resnet50 from resnet import resnet50
import random import random
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
ds.config.set_seed(1) ds.config.set_seed(1)
@ -53,7 +54,7 @@ device_id = int(os.getenv('DEVICE_ID'))
data_home = args_opt.dataset_path data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -137,7 +137,7 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
os.system("mkdir " + str(device_id)) os.system("mkdir " + str(device_id))
os.chdir(str(device_id)) os.chdir(str(device_id))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
@ -159,7 +159,7 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
def eval(batch_size, num_classes): def eval(batch_size, num_classes):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=0) context.set_context(device_id=0)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)

View File

@ -24,8 +24,7 @@ import mindspore.common.dtype as mstype
import os import os
import numpy as np import numpy as np
import mindspore.ops.functional as F import mindspore.ops.functional as F
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.callback import Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
@ -34,8 +33,6 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
from resnet import resnet50 from resnet import resnet50
import random import random
from multiprocessing import Process, Queue from multiprocessing import Process, Queue
from multiprocessing import Pool
import time
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
@ -150,7 +147,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
os.chdir(str(device_id)) os.chdir(str(device_id))
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False) device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)
os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
@ -206,9 +203,9 @@ def test_resnet_cifar_8p():
loss = 0.0 loss = 0.0
for i in range(device_num): for i in range(device_num):
loss += q.get() loss += q.get()
loss = loss/device_num loss = loss / device_num
for i in range(device_num): for i in range(device_num):
os.system("rm -rf " + str(i)) os.system("rm -rf " + str(i))
print("End training...") print("End training...")
assert(loss < 2.0) assert (loss < 2.0)

View File

@ -22,7 +22,6 @@ from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model, ParallelMode
from mindspore import context from mindspore import context
import os
import re import re
import mindspore.ops.functional as F import mindspore.ops.functional as F
from mindspore.nn.loss.loss import _Loss from mindspore.nn.loss.loss import _Loss
@ -32,38 +31,43 @@ from mindspore.parallel import set_algo_parameters
from mindspore.parallel import _cost_model_context as cost_model_context from mindspore.parallel import _cost_model_context as cost_model_context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id= 0) context.set_context(device_id=0)
context.set_context(enable_ir_fusion=True)
context.set_context(enable_loop_sink=False) context.set_context(enable_loop_sink=False)
init() init()
def weight_variable(shape, factor=0.1): def weight_variable(shape, factor=0.1):
return TruncatedNormal(0.02) return TruncatedNormal(0.02)
def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
"""Get a conv2d layer with 3x3 kernel size.""" """Get a conv2d layer with 3x3 kernel size."""
init_value = weight_variable((out_channels, in_channels, 3, 3)) init_value = weight_variable((out_channels, in_channels, 3, 3))
return nn.Conv2d(in_channels, out_channels, return nn.Conv2d(in_channels, out_channels,
kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value) kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
"""Get a conv2d layer with 1x1 kernel size.""" """Get a conv2d layer with 1x1 kernel size."""
init_value = weight_variable((out_channels, in_channels, 1, 1)) init_value = weight_variable((out_channels, in_channels, 1, 1))
return nn.Conv2d(in_channels, out_channels, return nn.Conv2d(in_channels, out_channels,
kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value) kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
"""Get a conv2d layer with 7x7 kernel size.""" """Get a conv2d layer with 7x7 kernel size."""
init_value = weight_variable((out_channels, in_channels, 7, 7)) init_value = weight_variable((out_channels, in_channels, 7, 7))
return nn.Conv2d(in_channels, out_channels, return nn.Conv2d(in_channels, out_channels,
kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value) kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _fused_bn(channels, momentum=0.9): def _fused_bn(channels, momentum=0.9):
"""Get a fused batchnorm""" """Get a fused batchnorm"""
init_weight = weight_variable((channels,)) init_weight = weight_variable((channels,))
init_bias = weight_variable((channels,)) init_bias = weight_variable((channels,))
return nn.BatchNorm2d(channels, momentum=momentum) return nn.BatchNorm2d(channels, momentum=momentum)
class ResidualBlock(nn.Cell): class ResidualBlock(nn.Cell):
expansion = 4 expansion = 4
@ -128,7 +132,7 @@ class ResNet(nn.Cell):
layer_nums, layer_nums,
in_channels, in_channels,
out_channels, out_channels,
strides=[1,2,2,2], strides=[1, 2, 2, 2],
num_classes=100): num_classes=100):
super(ResNet, self).__init__() super(ResNet, self).__init__()
@ -211,6 +215,7 @@ def resnet50(class_num=10):
[2, 2, 2, 1], [2, 2, 2, 1],
class_num) class_num)
class SoftmaxCrossEntropyExpand(_Loss): class SoftmaxCrossEntropyExpand(_Loss):
def __init__(self, sparse=False): def __init__(self, sparse=False):
super(SoftmaxCrossEntropyExpand, self).__init__() super(SoftmaxCrossEntropyExpand, self).__init__()
@ -304,7 +309,7 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
return allreduce_fusion_dict return allreduce_fusion_dict
def train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 def train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): # 1048576 #131072 #32768 #8192
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@ -476,7 +481,7 @@ def train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #10485
cost_model_context.reset_cost_model_context() cost_model_context.reset_cost_model_context()
def train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 def train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # 1048576 #131072 #32768 #8192
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
@ -649,7 +654,7 @@ def train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): #10485
cost_model_context.reset_cost_model_context() cost_model_context.reset_cost_model_context()
def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192 def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): # 1048576 #131072 #32768 #8192
dev_num = 8 dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
@ -668,7 +673,7 @@ def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576
model.train(5, dataset, dataset_sink_mode=False) model.train(5, dataset, dataset_sink_mode=False)
strategies = _executor._get_strategy(model._train_network) strategies = _executor._get_strategy(model._train_network)
for (k, v) in strategies.items(): for (k, v) in strategies.items():
if re.search('Conv2D-op', k ) is not None: if re.search('Conv2D-op', k) is not None:
assert v[0][0] == dev_num assert v[0][0] == dev_num
elif re.search('MatMul-op', k) is not None: elif re.search('MatMul-op', k) is not None:
assert v == [[1, 1], [dev_num, 1]] assert v == [[1, 1], [dev_num, 1]]

View File

@ -64,7 +64,6 @@ parser.add_argument('--path', default='./lenet_model.ms', type=str, help='model
if __name__ == '__main__': if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True)
print("test lenet predict start") print("test lenet predict start")
seed = 0 seed = 0

View File

@ -13,19 +13,15 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" test_multigraph_sink """ """ test_multigraph_sink """
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common import ms_function from mindspore.common import ms_function
from mindspore.ops import operations as P
def setup_module(module): def setup_module(module):
context.set_context(mode = context.PYNATIVE_MODE, save_graphs = False, device_target = "Ascend") context.set_context(mode = context.PYNATIVE_MODE, save_graphs = False, device_target = "Ascend")
context.set_context(enable_task_sink = True, device_id = 0) context.set_context(device_id=0)
c1 = Tensor([2], mstype.int32) c1 = Tensor([2], mstype.int32)
@ -86,6 +82,8 @@ def while_by_while(x, y, z):
x = x + 1 x = x + 1
x = x + 1 x = x + 1
return x return x
@ms_function @ms_function
def while_in_while(x, y, z): def while_in_while(x, y, z):
out = c4 out = c4
@ -98,6 +96,7 @@ def while_in_while(x, y, z):
out = out + x out = out + x
return out return out
def test_simple_if(): def test_simple_if():
output = simple_if(c1, c2, c3) output = simple_if(c1, c2, c3)
expect = Tensor([6], mstype.int32) expect = Tensor([6], mstype.int32)
@ -127,7 +126,8 @@ def test_while_by_while():
expect = Tensor([28], mstype.int32) expect = Tensor([28], mstype.int32)
assert output == expect assert output == expect
def test_while_in_while(): def test_while_in_while():
output = while_in_while(c1, c2, c3) output = while_in_while(c1, c2, c3)
expect = Tensor([1274], mstype.int32) expect = Tensor([1274], mstype.int32)
assert output == expect assert output == expect