forked from mindspore-Ecosystem/mindspore
modify
This commit is contained in:
parent
ec7cbb9929
commit
913b5b03df
|
@ -0,0 +1,8 @@
|
|||
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
|
||||
RootBlockBeta, resnet50_dl
|
||||
|
||||
__all__= [
|
||||
"Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta",
|
||||
"resnet50_dl"
|
||||
]
|
||||
|
|
@ -532,3 +532,6 @@ class RootBlockBeta(nn.Cell):
|
|||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
return x
|
||||
|
||||
class resnet50_dl(fine_tune_batch_norm=False):
|
||||
return ResNetV1(fine_tune_batch_norm)
|
||||
|
|
|
@ -17,7 +17,7 @@ import abc
|
|||
import os
|
||||
import time
|
||||
|
||||
from .utils.adapter import get_manifest_samples, get_raw_samples, read_image
|
||||
from .utils.adapter import get_raw_samples, read_image
|
||||
|
||||
|
||||
class BaseDataset(object):
|
||||
|
@ -62,29 +62,6 @@ class BaseDataset(object):
|
|||
pass
|
||||
|
||||
|
||||
class HwVocManifestDataset(BaseDataset):
|
||||
"""
|
||||
Create dataset with manifest data.
|
||||
|
||||
Args:
|
||||
data_url (str): The path of data.
|
||||
usage (str): Whether to use train or eval (default='train').
|
||||
|
||||
Returns:
|
||||
Dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, data_url, usage="train"):
|
||||
super().__init__(data_url, usage)
|
||||
|
||||
def _load_samples(self):
|
||||
try:
|
||||
self.samples = get_manifest_samples(self.data_url, self.usage)
|
||||
except Exception as e:
|
||||
print("load HwVocManifestDataset samples failed!!!")
|
||||
raise e
|
||||
|
||||
|
||||
class HwVocRawDataset(BaseDataset):
|
||||
"""
|
||||
Create dataset with raw data.
|
||||
|
|
|
@ -17,7 +17,7 @@ from PIL import Image
|
|||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
|
||||
from .ei_dataset import HwVocManifestDataset, HwVocRawDataset
|
||||
from .ei_dataset import HwVocRawDataset
|
||||
from .utils import custom_transforms as tr
|
||||
|
||||
|
||||
|
@ -77,10 +77,7 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
|
|||
Dataset.
|
||||
"""
|
||||
# create iter dataset
|
||||
if data_url.endswith(".manifest"):
|
||||
dataset = HwVocManifestDataset(data_url, usage=usage)
|
||||
else:
|
||||
dataset = HwVocRawDataset(data_url, usage=usage)
|
||||
dataset = HwVocRawDataset(data_url, usage=usage)
|
||||
dataset_len = len(dataset)
|
||||
|
||||
# wrapped with GeneratorDataset
|
||||
|
@ -100,5 +97,4 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
|
|||
dataset = dataset.repeat(count=epoch_num)
|
||||
dataset.map_model = 4
|
||||
|
||||
dataset.__loop_size__ = 1
|
||||
return dataset
|
||||
|
|
|
@ -87,13 +87,13 @@ if __name__ == "__main__":
|
|||
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
|
||||
callback.append(ckpoint_cb)
|
||||
net = deeplabv3_resnet50(crop_size.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
|
||||
infer_scale_sizes=crop_size.eval_scales, atrous_rates=crop_size.atrous_rates,
|
||||
decoder_output_stride=crop_size.decoder_output_stride, output_stride = crop_size.output_stride,
|
||||
fine_tune_batch_norm=crop_size.fine_tune_batch_norm, image_pyramid = crop_size.image_pyramid)
|
||||
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
|
||||
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
|
||||
decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
|
||||
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
|
||||
net.set_train()
|
||||
model_fine_tune(args_opt, net, 'layer')
|
||||
loss = OhemLoss(crop_size.seg_num_classes, crop_size.ignore_label)
|
||||
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=args_opt.learning_rate, momentum=args_opt.momentum, weight_decay=args_opt.weight_decay)
|
||||
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
|
||||
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
|
||||
model = Model(net, loss, opt)
|
||||
model.train(args_opt.epoch_size, train_dataset, callback)
|
Loading…
Reference in New Issue