forked from mindspore-Ecosystem/mindspore
!22225 fix midas pynative issue
Merge pull request !22225 from JichenZhao/master
This commit is contained in:
commit
c24dc871e0
|
@ -121,7 +121,7 @@ batch_size: 2
|
|||
loss_scale: 256
|
||||
momentum: 0.91
|
||||
weight_decay: 0.00001
|
||||
epoch_size: 20
|
||||
epoch_size: 12
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 1
|
||||
keep_checkpoint_max: 5
|
||||
|
@ -162,7 +162,7 @@ device_num: 1
|
|||
rank_id: 0
|
||||
image_dir: ''
|
||||
anno_path: ''
|
||||
backbone: 'resnet_v1_50'
|
||||
backbone: 'resnet_v1.5_50'
|
||||
|
||||
# eval.py FasterRcnn evaluation
|
||||
ann_file: '/cache/data/annotations/instances_val2017.json'
|
||||
|
|
|
@ -44,6 +44,7 @@ if config.backbone in ("resnet_v1.5_50", "resnet_v1_101", "resnet_v1_152"):
|
|||
from src.FasterRcnn.faster_rcnn_resnet import Faster_Rcnn_Resnet
|
||||
elif config.backbone == "resnet_v1_50":
|
||||
from src.FasterRcnn.faster_rcnn_resnet50v1 import Faster_Rcnn_Resnet
|
||||
config.epoch_size = 20
|
||||
|
||||
if config.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore import nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
|
||||
|
@ -29,6 +30,8 @@ from src.midas_net import MidasNet, Loss, NetwithCell
|
|||
from src.utils import loadImgDepth
|
||||
from src.config import config
|
||||
|
||||
set_seed(1)
|
||||
ds.config.set_seed(1)
|
||||
|
||||
def dynamic_lr(num_epoch_per_decay, total_epochs, steps_per_epoch, lr, end_lr):
|
||||
"""dynamic learning rate generator"""
|
||||
|
@ -78,8 +81,7 @@ def train(mixdata_path):
|
|||
max_call_depth=10000)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
device_num=device_num,
|
||||
parameter_broadcast=True
|
||||
device_num=device_num
|
||||
)
|
||||
init()
|
||||
local_data_path = config.train_data_dir + '/data'
|
||||
|
|
|
@ -24,7 +24,7 @@ set -e
|
|||
|
||||
RANK_SIZE=$1
|
||||
export RANK_SIZE
|
||||
|
||||
export HCCL_CONNECT_TIMEOUT=600
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
|
|
Loading…
Reference in New Issue