forked from mindspore-Ecosystem/mindspore
enable_a_m_p
This commit is contained in:
parent 40c20f16f0
commit 6635e42a46
@@ -44,7 +44,7 @@ def modelarts_pre_process():
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def train():
     device_num = get_device_num()
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False,
                         device_target=config.device_target, device_id=get_device_id())
     # init multicards training
     config.rank = 0
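Every `context.set_context` hunk in this commit makes the same change: the deprecated `enable_auto_mixed_precision` flag is dropped from the context call. For orientation only (this sketch is not part of the diff), mixed precision in these model-zoo scripts is normally requested through the `Model` wrapper's `amp_level` argument instead; the tiny `nn.Dense` backbone below is a placeholder:

import mindspore.nn as nn
from mindspore import Model, context

# Same context setup as the touched scripts, minus the removed AMP flag.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = nn.Dense(32, 10)  # placeholder backbone, stands in for the real network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# amp_level selects automatic mixed precision at the Model level:
# "O0" keeps float32, "O2"/"O3" cast most of the network to float16 on Ascend.
model = Model(net, loss_fn=loss, optimizer=opt, metrics={"acc"}, amp_level="O3")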
@@ -39,7 +39,7 @@ context.set_context(mode=context.GRAPH_MODE,
 
 if config.device_target == "Ascend":
     context.set_context(device_id=dev_id)
-    context.set_context(enable_auto_mixed_precision=False)
 
 def modelarts_process():
     config.data_dir = config.data_path
@@ -55,8 +55,6 @@ dev_id = get_device_id()
 context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target,
                     save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False)
 
-if config.device_target == "Ascend":
-    context.set_context(enable_auto_mixed_precision=False)
 
 if config.lr_scheduler == 'cosine_annealing' and config.max_epoch > config.t_max:
     config.t_max = config.max_epoch
@@ -29,7 +29,7 @@ from src.model_utils.moxing_adapter import moxing_wrapper
 
 
 context.set_context(mode=context.GRAPH_MODE, save_graphs=False,
-                    save_graphs_path=".", enable_auto_mixed_precision=False)
+                    save_graphs_path=".")
 
 
 def test_dataset_creator():
@@ -65,7 +65,7 @@ def train():
 
     if target == "Ascend":
         device_id = get_device_id()
-        context.set_context(device_id=device_id, enable_auto_mixed_precision=False)
+        context.set_context(device_id=device_id)
 
     if config.run_distribute:
         init()
@@ -105,7 +105,7 @@ def run_eval():
                                       datetime.datetime.now().strftime("%Y-%m-%d_time_%H_%M_%S"))
     config.logger = get_logger(config.outputs_dir, config.rank)
 
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=config.device_target, save_graphs=False, device_id=get_device_id())
     config.logger.save_args(config)
 
@@ -152,7 +152,7 @@ def modelarts_pre_process():
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def run_train():
     config = set_default_args(default_config)
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=config.device_target, save_graphs=False, device_id=get_device_id())
     if config.is_distributed:
         parallel_mode = ParallelMode.DATA_PARALLEL
@@ -109,7 +109,7 @@ def train():
     if args.device_target == "CPU":
         context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
     else:
-        context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
+        context.set_context(mode=context.GRAPH_MODE, save_graphs=False,
                             device_target="Ascend", device_id=get_device_id())
 
     # init multicards training
@@ -85,7 +85,7 @@ def train():
     if args.device_target == "CPU":
         context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
     else:
-        context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
+        context.set_context(mode=context.GRAPH_MODE, save_graphs=False,
                             device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
     # init multicards training
     if args.modelArts_mode:
@@ -130,7 +130,7 @@ def train():
     config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
     config.image_size = list(map(int, config.image_size.split(',')))
 
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=config.device_target, save_graphs=False)
 
     if config.device_target == 'Ascend':
@@ -49,7 +49,6 @@ args, _ = parser.parse_known_args()
 
 context.set_context(
     mode=context.GRAPH_MODE,
-    enable_auto_mixed_precision=True,
     device_target=args.device_target,
     save_graphs=False,
     device_id=args.device_num)
@@ -189,7 +189,6 @@ if args.is_distributed:
 
 context.set_context(
     mode=context.GRAPH_MODE,
-    enable_auto_mixed_precision=True,
     device_target=args.device_target,
     save_graphs=False,
     device_id=args.rank)
@@ -115,7 +115,7 @@ def train_mobilenetv1():
     device_id = int(os.getenv('DEVICE_ID', '0'))
     if config.run_distribute:
         if target == "Ascend":
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
             init()
@@ -62,7 +62,7 @@ if __name__ == '__main__':
     if args_opt.run_distribute:
         if device_target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True,
                                               auto_parallel_search_mode="recursive_programming")
@@ -105,7 +105,7 @@ def set_parameter():
     if config.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(device_num=config.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
             set_algo_parameters(elementwise_op_strategy_follow=True)
@@ -56,8 +56,7 @@ if args_opt.device_target == "Ascend":
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="Ascend",
                         save_graphs=False,
-                        device_id=device_id,
-                        enable_auto_mixed_precision=True)
+                        device_id=device_id)
 else:
     raise ValueError("Unsupported device target.")
 
@@ -109,7 +109,7 @@ if __name__ == '__main__':
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
             set_algo_parameters(elementwise_op_strategy_follow=True)
@@ -132,8 +132,7 @@ def set_graph_kernel_context(device_target):
 def test():
     """test"""
     set_parameters()
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                        device_target=config.device_target, save_graphs=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
     if os.getenv('DEVICE_ID', "not_set").isdigit():
         context.set_context(device_id=int(os.getenv('DEVICE_ID')))
     set_graph_kernel_context(config.device_target)
@@ -106,8 +106,7 @@ class ProgressMonitor(Callback):
 
 def set_parameters():
     """parameters"""
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                        device_target=config.device_target, save_graphs=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
     # init distributed
     if config.run_distribute:
         init()
@@ -59,8 +59,7 @@ def train_net():
     if target == "Ascend":
         device_id = get_device_id()
         device_num = config.device_num
-        context.set_context(device_id=device_id,
-                            enable_auto_mixed_precision=True)
+        context.set_context(device_id=device_id)
         context.set_auto_parallel_context(
             device_num=device_num,
             parallel_mode=ParallelMode.DATA_PARALLEL,
@@ -136,7 +136,7 @@ def run_eval():
 
     _enable_graph_kernel = config.device_target == "GPU"
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=_enable_graph_kernel,
-                        enable_auto_mixed_precision=True, device_target=config.device_target, save_graphs=False)
+                        device_target=config.device_target, save_graphs=False)
     if os.getenv('DEVICE_ID', "not_set").isdigit() and config.device_target == "Ascend":
         context.set_context(device_id=int(os.getenv('DEVICE_ID')))
 
@@ -69,7 +69,7 @@ def set_graph_kernel_context():
 
 def network_init(args):
     devid = int(os.getenv('DEVICE_ID', '0'))
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=args.device_target, save_graphs=False, device_id=devid)
     set_graph_kernel_context()
 
@@ -43,7 +43,7 @@ from src.util import ShapeRecord
 set_seed(1)
 
 devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+context.set_context(mode=context.GRAPH_MODE,
                     device_target="Ascend", save_graphs=False, device_id=devid)
 
 
@@ -58,7 +58,7 @@ def set_default():
     config.ann_val_file = os.path.join(config.data_dir, 'annotations/instances_val2017.json')
 
     device_id = int(os.getenv('DEVICE_ID', '0'))
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=config.device_target, save_graphs=False, device_id=device_id)
 
     if config.need_profiler:
@@ -108,6 +108,7 @@ args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
 
 if args.is_modelArts:
     args.data_root = os.path.join(args.data_dir, 'train2017')
 
     args.annFile = os.path.join(args.data_dir, 'annotations')
     outputs_dir = os.path.join('/cache', args.ckpt_path)
 else:
@@ -117,12 +118,13 @@ else:
     outputs_dir = args.ckpt_path
 
 deviced = int(os.getenv('DEVICE_ID', '0'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, device_target=args.device_target,
+context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                     save_graphs=False, device_id=deviced)
 # init distributed
 if args.is_distributed:
     if args.device_target == "Ascend":
         init()
 
     else:
         init("nccl")
     args.rank = get_rank()
@@ -115,8 +115,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
         op_sqrt = P.Sqrt()
         scatter_add = P.ScatterAdd(use_locking)
 
-        assign_m = F.assign(m, op_mul(beta1, m))
-        assign_v = F.assign(v, op_mul(beta2, v))
+        F.assign(m, op_mul(beta1, m))
+        F.assign(v, op_mul(beta2, v))
 
         grad_indices = gradient.indices
         grad_value = gradient.values
@@ -131,17 +131,15 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         if use_nesterov:
             m_temp = next_m * _scaler_ten
-            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
+            F.assign(m, op_mul(beta1, next_m))
             div_value = scatter_add(m,
                                     op_mul(grad_indices, _scaler_one),
                                     op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
             param_update = div_value / (op_sqrt(next_v) + eps)
 
-            m_recover = F.assign(m, m_temp / _scaler_ten)
+            F.assign(m, m_temp / _scaler_ten)
 
-            F.control_depend(m_temp, assign_m_nesterov)
-            F.control_depend(assign_m_nesterov, div_value)
-            F.control_depend(param_update, m_recover)
+
         else:
             param_update = next_m / (op_sqrt(next_v) + eps)
 
@@ -149,8 +147,7 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         next_param = param - lr_t * param_update
 
-        F.control_depend(assign_m, next_m)
-        F.control_depend(assign_v, next_v)
+
 
         success = F.depend(success, F.assign(param, next_param))
         success = F.depend(success, F.assign(m, next_m))
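The optimizer hunks above drop `F.control_depend`, which no longer exists in newer MindSpore, and stop binding the `F.assign` results to names. Ordering is expressed with `F.depend` instead, as the unchanged `success = F.depend(...)` lines already do. A minimal, hedged sketch of that idiom (the toy cell and values are illustrative, not from this repository):

import mindspore
import mindspore.nn as nn
from mindspore import Parameter, Tensor
from mindspore.ops import functional as F


class AssignTwice(nn.Cell):
    """Toy cell: two in-place assigns chained through F.depend."""

    def __init__(self):
        super().__init__()
        self.param = Parameter(Tensor(1.0, mindspore.float32), name="param")
        self.m = Parameter(Tensor(0.0, mindspore.float32), name="m")

    def construct(self, new_param, new_m):
        success = True
        # F.depend(value, expr) returns value but forces expr to run first,
        # so the assign side effects are kept and ordered without control_depend.
        success = F.depend(success, F.assign(self.param, new_param))
        success = F.depend(success, F.assign(self.m, new_m))
        return success


net = AssignTwice()
net(Tensor(2.0, mindspore.float32), Tensor(0.5, mindspore.float32))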
@@ -119,7 +119,7 @@ def run_transformer_train():
         context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
     else:
         context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
-    context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
+    context.set_context(reserve_class_name_in_scope=False)
 
     if config.device_target == "GPU":
         # Enable graph kernel
@@ -43,8 +43,8 @@ from model_utils.device_adapter import get_device_id, get_device_num, get_rank_i
 mindspore.common.seed.set_seed(1)
 context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False,
                     reserve_class_name_in_scope=False, enable_graph_kernel=config.device_target == "GPU")
-if config.device_target == 'Ascend':
-    context.set_context(enable_auto_mixed_precision=False)
 if config.device_target != 'GPU' or not config.is_distributed:
     context.set_context(device_id=get_device_id())
 
@@ -123,8 +123,7 @@ if __name__ == '__main__':
 
     if run_distribute:
 
-        context.set_context(device_id=device_id,
-                            enable_auto_mixed_precision=True)
+        context.set_context(device_id=device_id)
         context.set_auto_parallel_context(device_num=device_num,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
@@ -101,8 +101,7 @@ if __name__ == '__main__':
 
     if run_distribute:
 
-        context.set_context(device_id=device_id,
-                            enable_auto_mixed_precision=True)
+        context.set_context(device_id=device_id)
         context.set_auto_parallel_context(device_num=device_num,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
@@ -172,7 +172,7 @@ if __name__ == "__main__":
     if args.run_distribute:
         device_num = int(os.getenv('RANK_SIZE'))
         device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+        context.set_context(device_id=device_id)
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
         init()
@@ -65,7 +65,7 @@ def main():
     if args.run_distribute:
         device_num = int(os.getenv('DEVICE_NUM'))
         rank_id = int(os.getenv("RANK_ID"))
-        context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+        context.set_context(mode=context.GRAPH_MODE,
                             device_target=args.device_target, device_id=device_id)
         init()
         context.reset_auto_parallel_context()
@@ -73,7 +73,7 @@ def main():
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
     else:
-        context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+        context.set_context(mode=context.GRAPH_MODE,
                             device_target=args.device_target, device_id=device_id)
 
     # define save checkpoint flag
@@ -63,7 +63,7 @@ if __name__ == '__main__':
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
             set_algo_parameters(elementwise_op_strategy_follow=True)
@@ -50,8 +50,7 @@ if __name__ == '__main__':
     if conf.device_target == 'Ascend':
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(
-            device_id=device_id,
-            enable_auto_mixed_precision=True,
+            device_id=device_id
         )
         context.set_auto_parallel_context(
             device_num=conf.device_num,
@@ -114,7 +114,7 @@ def train():
     device_num = 1
     num_workers = 8
     if config.device_target == "Ascend":
-        context.set_context(enable_auto_mixed_precision=False)
         context.set_context(device_id=get_device_id())
         if config.distribute == "true":
             D.init()
@@ -108,10 +108,11 @@ def train():
     rank = 0
     device_num = 1
     num_workers = 8
 
     if config.device_target == "Ascend":
-        context.set_context(enable_auto_mixed_precision=False)
         context.set_context(device_id=get_device_id())
         if config.distribute == "true":
 
             D.init()
             device_num = get_device_num()
             rank = get_rank_id()
@@ -109,7 +109,7 @@ def train():
     device_num = 1
     num_workers = 8
     if config.device_target == "Ascend":
-        context.set_context(enable_auto_mixed_precision=False)
         context.set_context(device_id=get_device_id())
         if config.distribute == "true":
             D.init()
@@ -136,7 +136,6 @@ def train():
     device_num = 1
     num_workers = 8
     if args_opt.device_target == "Ascend":
-        context.set_context(enable_auto_mixed_precision=False)
         context.set_context(device_id=args_opt.device_id)
         if args_opt.distribute == "true":
             D.init()
@@ -102,7 +102,7 @@ def train():
     args = parse_args()
     context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=args.device_target)
     if args.device_target != "CPU":
-        context.set_context(enable_auto_mixed_precision=True, device_id=args.device_id)
+        context.set_context(device_id=args.device_id)
 
     # init multicards training
     if args.modelArts_mode:
@@ -59,7 +59,7 @@ if __name__ == '__main__':
     rank_size = int(os.environ.get("RANK_SIZE", 1))
     print(rank_size)
     device_num = rank_size
-    context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+    context.set_context(device_id=device_id)
     context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                       gradients_mean=True)
     init()
@@ -83,8 +83,7 @@ if __name__ == '__main__':
     local_data_path = args_opt.data_url
     print('Download data:')
     dataset = create_dataset(dataset_path=local_data_path,
-                             do_train=True,
-                             target="Ascend")
+                             do_train=True)
 
     step_size = dataset.get_dataset_size()
     print('steps:', step_size)
@@ -77,7 +77,7 @@ if __name__ == '__main__':
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True,
                                               auto_parallel_search_mode="recursive_programming")
@@ -72,7 +72,7 @@ if __name__ == '__main__':
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
             init()
         else:
@@ -56,8 +56,7 @@ if args.isModelArts:
 
 if __name__ == '__main__':
     target = args.device_target
-    context.set_context(mode=context.GRAPH_MODE, device_target=target,
-                        enable_auto_mixed_precision=True, save_graphs=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
 
     if args.distribute:
         if target == "Ascend":
@@ -59,7 +59,7 @@ if __name__ == "__main__":
     step = 60
     target = args.device_target
     context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
-    context.set_context(device_id=args.device_id, enable_auto_mixed_precision=True)
+    context.set_context(device_id=args.device_id)
 
     lr = lr_generator(cfg.lr, train_epoch, steps_per_epoch=step)
     net = resnet50_ibn_a(num_classes=cfg.class_num)
@@ -97,8 +97,7 @@ if __name__ == "__main__":
     if args.device_num > 1:
         if target == 'Ascend':
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id,
-                                enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True,
                                               auto_parallel_search_mode="recursive_programming")
@@ -93,7 +93,7 @@ def train(mixdata_path):
     load_path = config.train_data_dir + '/midas/ckpt/midas_resnext_101_WSL.ckpt'
     device_id = config.device_id
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id,
-                        enable_auto_mixed_precision=True, max_call_depth=10000)
+                        max_call_depth=10000)
     # load data
     f = open(mixdata_path)
     data_config = json.load(f)
@@ -80,7 +80,7 @@ if __name__ == '__main__':
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             # init parallel training parameters
             context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
@@ -181,7 +181,7 @@ def get_result(args, model, top1_correct, top5_correct, img_tot):
 def test(cloud_args=None):
     """test"""
     args = parse_args(cloud_args)
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=args.platform, save_graphs=False)
     if os.getenv('DEVICE_ID', "not_set").isdigit():
         context.set_context(device_id=int(os.getenv('DEVICE_ID')))
@@ -192,7 +192,7 @@ def parse_args(cloud_args=None):
     args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
     args.image_size = list(map(int, args.image_size.split(',')))
 
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=args.platform, save_graphs=False)
     # init distributed
     if args.is_distributed:
@@ -65,7 +65,7 @@ if __name__ == '__main__':
     context.set_ps_context(enable_ps=True)
     if args_opt.run_distribute:
         device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
+        context.set_context(device_id=device_id)
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
         set_algo_parameters(elementwise_op_strategy_follow=True)
@@ -69,8 +69,7 @@ if __name__ == '__main__':
     if args_opt.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(device_id=device_id,
-                                enable_auto_mixed_precision=True)
+            context.set_context(device_id=device_id)
             context.set_auto_parallel_context(
                 device_num=args_opt.device_num,
                 parallel_mode=ParallelMode.DATA_PARALLEL,
@@ -67,8 +67,7 @@ if __name__ == '__main__':
     device_id = int(os.getenv("DEVICE_ID"))
     context.set_context(mode=context.GRAPH_MODE,
                         device_target=target)
-    context.set_context(device_id=device_id,
-                        enable_auto_mixed_precision=True)
+    context.set_context(device_id=device_id)
     context.set_auto_parallel_context(
         device_num=device_num,
         parallel_mode=ParallelMode.DATA_PARALLEL,
@@ -136,7 +136,7 @@ def run_eval():
 
     _enable_graph_kernel = config.device_target == "GPU"
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=_enable_graph_kernel,
-                        enable_auto_mixed_precision=True, device_target=config.device_target, save_graphs=False)
+                        device_target=config.device_target, save_graphs=False)
     if os.getenv('DEVICE_ID', "not_set").isdigit() and config.device_target == "Ascend":
         context.set_context(device_id=int(os.getenv('DEVICE_ID')))
 
@@ -53,7 +53,7 @@ def set_default():
     config.ann_val_file = os.path.join(config.data_dir, 'annotations/instances_val2017.json')
 
     device_id = int(os.getenv('DEVICE_ID', '0'))
-    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+    context.set_context(mode=context.GRAPH_MODE,
                         device_target=config.device_target, save_graphs=False, device_id=device_id)
 
     if config.need_profiler:
@@ -96,7 +96,7 @@ def modelarts_pre_process():
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def train():
     '''Train.'''
-    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, enable_auto_mixed_precision=True)
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
     config.rank_save_ckpt_flag = 1
 
     # init distributed
@@ -411,7 +411,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
                                         self.params, self.moments1, self.moments2, gradients, self.decay_flag)
 
         added_global_step = self.global_step + self.one
-        F.control_depend(lr, added_global_step)
         self.global_step = added_global_step
 
         return updated_velocity
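The `AdamWeightDecayDynamicLR` hunk removes the `F.control_depend(lr, added_global_step)` edge and keeps only the plain parameter update. As a hedged illustration (not this repository's optimizer), the step-counter bookkeeping still works because assigning back to a `Parameter` inside `construct` is itself a tracked side effect in graph mode:

import mindspore
import mindspore.nn as nn
from mindspore import Parameter, Tensor


class StepCounter(nn.Cell):
    """Illustrative cell: bump a global step without control_depend."""

    def __init__(self):
        super().__init__()
        self.global_step = Parameter(Tensor(0, mindspore.int32), name="global_step")
        self.one = Tensor(1, mindspore.int32)

    def construct(self):
        added_global_step = self.global_step + self.one
        self.global_step = added_global_step  # in-place Parameter update
        return added_global_step


counter = StepCounter()
counter()  # returns 1 and advances global_step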
@@ -130,8 +130,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
         op_sqrt = P.Sqrt()
         scatter_add = P.ScatterAdd(use_locking)
 
-        assign_m = F.assign(m, op_mul(beta1, m))
-        assign_v = F.assign(v, op_mul(beta2, v))
+        F.assign(m, op_mul(beta1, m))
+        F.assign(v, op_mul(beta2, v))
 
         grad_indices = gradient.indices
         grad_value = gradient.values
@@ -146,17 +146,15 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         if use_nesterov:
             m_temp = next_m * _scaler_ten
-            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
+            F.assign(m, op_mul(beta1, next_m))
             div_value = scatter_add(m,
                                     op_mul(grad_indices, _scaler_one),
                                     op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
             param_update = div_value / (op_sqrt(next_v) + eps)
 
-            m_recover = F.assign(m, m_temp / _scaler_ten)
+            F.assign(m, m_temp / _scaler_ten)
 
-            F.control_depend(m_temp, assign_m_nesterov)
-            F.control_depend(assign_m_nesterov, div_value)
-            F.control_depend(param_update, m_recover)
+
         else:
             param_update = next_m / (op_sqrt(next_v) + eps)
 
@@ -164,8 +162,6 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         next_param = param - lr_t * param_update
 
-        F.control_depend(assign_m, next_m)
-        F.control_depend(assign_v, next_v)
 
         success = F.depend(success, F.assign(param, next_param))
         success = F.depend(success, F.assign(m, next_m))
@@ -129,8 +129,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
         op_sqrt = P.Sqrt()
         scatter_add = P.ScatterAdd(use_locking)
 
-        assign_m = F.assign(m, op_mul(beta1, m))
-        assign_v = F.assign(v, op_mul(beta2, v))
+        F.assign(m, op_mul(beta1, m))
+        F.assign(v, op_mul(beta2, v))
 
         grad_indices = gradient.indices
         grad_value = gradient.values
@@ -145,17 +145,14 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         if use_nesterov:
             m_temp = next_m * _scaler_ten
-            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
+            F.assign(m, op_mul(beta1, next_m))
             div_value = scatter_add(m,
                                     op_mul(grad_indices, _scaler_one),
                                     op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
             param_update = div_value / (op_sqrt(next_v) + eps)
 
-            m_recover = F.assign(m, m_temp / _scaler_ten)
+            F.assign(m, m_temp / _scaler_ten)
 
-            F.control_depend(m_temp, assign_m_nesterov)
-            F.control_depend(assign_m_nesterov, div_value)
-            F.control_depend(param_update, m_recover)
         else:
             param_update = next_m / (op_sqrt(next_v) + eps)
 
@@ -163,8 +160,7 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         next_param = param - lr_t * param_update
 
-        F.control_depend(assign_m, next_m)
-        F.control_depend(assign_v, next_v)
+
 
         success = F.depend(success, F.assign(param, next_param))
         success = F.depend(success, F.assign(m, next_m))
@@ -129,8 +129,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
         op_sqrt = P.Sqrt()
         scatter_add = P.ScatterAdd(use_locking)
 
-        assign_m = F.assign(m, op_mul(beta1, m))
-        assign_v = F.assign(v, op_mul(beta2, v))
+        F.assign(m, op_mul(beta1, m))
+        F.assign(v, op_mul(beta2, v))
 
         grad_indices = gradient.indices
         grad_value = gradient.values
@@ -145,17 +145,15 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         if use_nesterov:
             m_temp = next_m * _scaler_ten
-            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
+            F.assign(m, op_mul(beta1, next_m))
             div_value = scatter_add(m,
                                     op_mul(grad_indices, _scaler_one),
                                     op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
             param_update = div_value / (op_sqrt(next_v) + eps)
 
-            m_recover = F.assign(m, m_temp / _scaler_ten)
+            F.assign(m, m_temp / _scaler_ten)
 
-            F.control_depend(m_temp, assign_m_nesterov)
-            F.control_depend(assign_m_nesterov, div_value)
-            F.control_depend(param_update, m_recover)
+
         else:
             param_update = next_m / (op_sqrt(next_v) + eps)
 
@@ -163,8 +161,7 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         next_param = param - lr_t * param_update
 
-        F.control_depend(assign_m, next_m)
-        F.control_depend(assign_v, next_v)
+
 
         success = F.depend(success, F.assign(param, next_param))
         success = F.depend(success, F.assign(m, next_m))