fix se-resnet performance and ctpn standalone scripts

This commit is contained in:
qujianwei 2021-05-08 15:24:11 +08:00
parent dcec57955c
commit ca895d2d6e
2 changed files with 11 additions and 4 deletions

View File

@ -13,9 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_train.sh [TASK_TYPE] [PRETRAINED_PATH] [DEVICE_ID]"
echo "for example: sh run_standalone_train.sh Pretraining /path/vgg16_backbone.ckpt 0"
echo "when device id is occupied, choose for another one"
echo "It is better to use absolute path."
echo "=============================================================================================================="
if [ $# -ne 3 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
echo "Usage: sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH] [DEVICE_ID]"
exit 1
fi
@ -38,7 +45,7 @@ fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export DEVICE_ID=$3
export RANK_ID=0
export RANK_SIZE=1

View File

@ -218,7 +218,7 @@ if __name__ == '__main__':
metrics = {"acc"}
if args_opt.run_distribute:
metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)}
if (args_opt.net not in ("resnet18", "resnet50", "resnet101")) or \
if (args_opt.net not in ("resnet18", "resnet50", "resnet101", "se-resnet50")) or \
args_opt.parameter_server or target == "CPU":
## fp32 training
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network)