!22225 fix midas pynative issue

Merge pull request !22225 from JichenZhao/master
This commit is contained in:
i-robot 2021-08-25 11:25:51 +00:00 committed by Gitee
commit c24dc871e0
4 changed files with 8 additions and 5 deletions

View File

@ -121,7 +121,7 @@ batch_size: 2
loss_scale: 256
momentum: 0.91
weight_decay: 0.00001
epoch_size: 20
epoch_size: 12
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 5
@ -162,7 +162,7 @@ device_num: 1
rank_id: 0
image_dir: ''
anno_path: ''
backbone: 'resnet_v1_50'
backbone: 'resnet_v1.5_50'
# eval.py FasterRcnn evaluation
ann_file: '/cache/data/annotations/instances_val2017.json'

View File

@ -44,6 +44,7 @@ if config.backbone in ("resnet_v1.5_50", "resnet_v1_101", "resnet_v1_152"):
from src.FasterRcnn.faster_rcnn_resnet import Faster_Rcnn_Resnet
elif config.backbone == "resnet_v1_50":
from src.FasterRcnn.faster_rcnn_resnet50v1 import Faster_Rcnn_Resnet
config.epoch_size = 20
if config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)

View File

@ -21,6 +21,7 @@ from mindspore import nn
from mindspore import Tensor
from mindspore.context import ParallelMode
import mindspore.dataset as ds
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.model import Model
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
@ -29,6 +30,8 @@ from src.midas_net import MidasNet, Loss, NetwithCell
from src.utils import loadImgDepth
from src.config import config
set_seed(1)
ds.config.set_seed(1)
def dynamic_lr(num_epoch_per_decay, total_epochs, steps_per_epoch, lr, end_lr):
"""dynamic learning rate generator"""
@ -78,8 +81,7 @@ def train(mixdata_path):
max_call_depth=10000)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
device_num=device_num,
parameter_broadcast=True
device_num=device_num
)
init()
local_data_path = config.train_data_dir + '/data'

View File

@ -24,7 +24,7 @@ set -e
RANK_SIZE=$1
export RANK_SIZE
export HCCL_CONNECT_TIMEOUT=600
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"