forked from mindspore-Ecosystem/mindspore
modify
This commit is contained in:
parent
ec7cbb9929
commit
913b5b03df
|
@ -0,0 +1,8 @@
|
||||||
|
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
|
||||||
|
RootBlockBeta, resnet50_dl
|
||||||
|
|
||||||
|
__all__= [
|
||||||
|
"Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta",
|
||||||
|
"resnet50_dl"
|
||||||
|
]
|
||||||
|
|
|
@ -532,3 +532,6 @@ class RootBlockBeta(nn.Cell):
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.conv3(x)
|
x = self.conv3(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class resnet50_dl(fine_tune_batch_norm=False):
|
||||||
|
return ResNetV1(fine_tune_batch_norm)
|
||||||
|
|
|
@ -17,7 +17,7 @@ import abc
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from .utils.adapter import get_manifest_samples, get_raw_samples, read_image
|
from .utils.adapter import get_raw_samples, read_image
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(object):
|
class BaseDataset(object):
|
||||||
|
@ -62,29 +62,6 @@ class BaseDataset(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HwVocManifestDataset(BaseDataset):
|
|
||||||
"""
|
|
||||||
Create dataset with manifest data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_url (str): The path of data.
|
|
||||||
usage (str): Whether to use train or eval (default='train').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dataset.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_url, usage="train"):
|
|
||||||
super().__init__(data_url, usage)
|
|
||||||
|
|
||||||
def _load_samples(self):
|
|
||||||
try:
|
|
||||||
self.samples = get_manifest_samples(self.data_url, self.usage)
|
|
||||||
except Exception as e:
|
|
||||||
print("load HwVocManifestDataset samples failed!!!")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
class HwVocRawDataset(BaseDataset):
|
class HwVocRawDataset(BaseDataset):
|
||||||
"""
|
"""
|
||||||
Create dataset with raw data.
|
Create dataset with raw data.
|
||||||
|
|
|
@ -17,7 +17,7 @@ from PIL import Image
|
||||||
import mindspore.dataset as de
|
import mindspore.dataset as de
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||||
|
|
||||||
from .ei_dataset import HwVocManifestDataset, HwVocRawDataset
|
from .ei_dataset import HwVocRawDataset
|
||||||
from .utils import custom_transforms as tr
|
from .utils import custom_transforms as tr
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,10 +77,7 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
|
||||||
Dataset.
|
Dataset.
|
||||||
"""
|
"""
|
||||||
# create iter dataset
|
# create iter dataset
|
||||||
if data_url.endswith(".manifest"):
|
dataset = HwVocRawDataset(data_url, usage=usage)
|
||||||
dataset = HwVocManifestDataset(data_url, usage=usage)
|
|
||||||
else:
|
|
||||||
dataset = HwVocRawDataset(data_url, usage=usage)
|
|
||||||
dataset_len = len(dataset)
|
dataset_len = len(dataset)
|
||||||
|
|
||||||
# wrapped with GeneratorDataset
|
# wrapped with GeneratorDataset
|
||||||
|
@ -100,5 +97,4 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
|
||||||
dataset = dataset.repeat(count=epoch_num)
|
dataset = dataset.repeat(count=epoch_num)
|
||||||
dataset.map_model = 4
|
dataset.map_model = 4
|
||||||
|
|
||||||
dataset.__loop_size__ = 1
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
|
@ -87,13 +87,13 @@ if __name__ == "__main__":
|
||||||
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
|
||||||
callback.append(ckpoint_cb)
|
callback.append(ckpoint_cb)
|
||||||
net = deeplabv3_resnet50(crop_size.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
|
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
|
||||||
infer_scale_sizes=crop_size.eval_scales, atrous_rates=crop_size.atrous_rates,
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
|
||||||
decoder_output_stride=crop_size.decoder_output_stride, output_stride = crop_size.output_stride,
|
decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
|
||||||
fine_tune_batch_norm=crop_size.fine_tune_batch_norm, image_pyramid = crop_size.image_pyramid)
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
|
||||||
net.set_train()
|
net.set_train()
|
||||||
model_fine_tune(args_opt, net, 'layer')
|
model_fine_tune(args_opt, net, 'layer')
|
||||||
loss = OhemLoss(crop_size.seg_num_classes, crop_size.ignore_label)
|
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
|
||||||
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=args_opt.learning_rate, momentum=args_opt.momentum, weight_decay=args_opt.weight_decay)
|
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
|
||||||
model = Model(net, loss, opt)
|
model = Model(net, loss, opt)
|
||||||
model.train(args_opt.epoch_size, train_dataset, callback)
|
model.train(args_opt.epoch_size, train_dataset, callback)
|
Loading…
Reference in New Issue