This commit is contained in:
unknown 2020-05-29 03:17:32 +08:00
parent ec7cbb9929
commit 913b5b03df
5 changed files with 20 additions and 36 deletions

View File

@ -0,0 +1,8 @@
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
RootBlockBeta, resnet50_dl
__all__= [
"Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta",
"resnet50_dl"
]

View File

@ -532,3 +532,6 @@ class RootBlockBeta(nn.Cell):
x = self.conv2(x)
x = self.conv3(x)
return x
class resnet50_dl(fine_tune_batch_norm=False):
return ResNetV1(fine_tune_batch_norm)

View File

@ -17,7 +17,7 @@ import abc
import os
import time
from .utils.adapter import get_manifest_samples, get_raw_samples, read_image
from .utils.adapter import get_raw_samples, read_image
class BaseDataset(object):
@ -62,29 +62,6 @@ class BaseDataset(object):
pass
class HwVocManifestDataset(BaseDataset):
"""
Create dataset with manifest data.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage="train"):
super().__init__(data_url, usage)
def _load_samples(self):
try:
self.samples = get_manifest_samples(self.data_url, self.usage)
except Exception as e:
print("load HwVocManifestDataset samples failed!!!")
raise e
class HwVocRawDataset(BaseDataset):
"""
Create dataset with raw data.

View File

@ -17,7 +17,7 @@ from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
from .ei_dataset import HwVocManifestDataset, HwVocRawDataset
from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr
@ -77,9 +77,6 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
Dataset.
"""
# create iter dataset
if data_url.endswith(".manifest"):
dataset = HwVocManifestDataset(data_url, usage=usage)
else:
dataset = HwVocRawDataset(data_url, usage=usage)
dataset_len = len(dataset)
@ -100,5 +97,4 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
dataset = dataset.repeat(count=epoch_num)
dataset.map_model = 4
dataset.__loop_size__ = 1
return dataset

View File

@ -87,13 +87,13 @@ if __name__ == "__main__":
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
callback.append(ckpoint_cb)
net = deeplabv3_resnet50(crop_size.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
infer_scale_sizes=crop_size.eval_scales, atrous_rates=crop_size.atrous_rates,
decoder_output_stride=crop_size.decoder_output_stride, output_stride = crop_size.output_stride,
fine_tune_batch_norm=crop_size.fine_tune_batch_norm, image_pyramid = crop_size.image_pyramid)
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
net.set_train()
model_fine_tune(args_opt, net, 'layer')
loss = OhemLoss(crop_size.seg_num_classes, crop_size.ignore_label)
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=args_opt.learning_rate, momentum=args_opt.momentum, weight_decay=args_opt.weight_decay)
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
model = Model(net, loss, opt)
model.train(args_opt.epoch_size, train_dataset, callback)