forked from mindspore-Ecosystem/mindspore
!12777 fix ci pynative case
From: @jojobugfree
Reviewed-by: @chujinjin, @jjfeing
Signed-off-by: @chujinjin
commit 541643e16b
@@ -32,7 +32,36 @@ from mindspore.nn import Cell
 from mindspore.ops import operations as P
 from mindspore.ops import composite as CP
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.nn.wrap.cell_wrapper import WithLossCell
+from mindspore.train.callback import LossMonitor, Callback
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.model import Model
+
+
+class MyTimeMonitor(Callback):
+    def __init__(self, data_size):
+        super(MyTimeMonitor, self).__init__()
+        self.data_size = data_size
+        self.total = 0
+
+    def epoch_begin(self, run_context):
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / self.data_size
+        print("epoch time:{0}, per step time:{1}".format(epoch_mseconds, per_step_mseconds), flush=True)
+
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+        step_mseconds = (time.time() - self.step_time) * 1000
+        if step_mseconds < 265:
+            self.total = self.total + 1
+        print(f"step time:{step_mseconds}", flush=True)
+
+    def good_step(self):
+        return self.total
+
 random.seed(1)
 np.random.seed(1)
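The new MyTimeMonitor hooks the per-step callbacks and counts every step that finishes in under 265 ms; good_step() reports that count so the test can assert on steady-state step latency. A minimal standalone sketch of the same counting idea, with no MindSpore dependency (StepTimer and the sleep-based fake step are illustrative, not from this commit):

import time

class StepTimer:
    """Count iterations that finish under a latency threshold, in ms."""
    def __init__(self, threshold_ms=265):
        self.threshold_ms = threshold_ms
        self.total = 0

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *exc):
        elapsed_ms = (time.time() - self.start) * 1000
        if elapsed_ms < self.threshold_ms:
            self.total += 1

timer = StepTimer()
for _ in range(5):
    with timer:
        time.sleep(0.01)    # stand-in for one training step
print(timer.total)          # 5: every fake step beat the 265 ms threshold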
@@ -303,12 +332,12 @@ def resnet50(batch_size, num_classes):
     return ResNet(ResidualBlock, num_classes, batch_size)


-def create_dataset(repeat_num=1, training=True, batch_size=32):
+def create_dataset(repeat_num=1, training=True, batch_size=32, num_samples=1600):
     data_home = "/home/workspace/mindspore_dataset"
     data_dir = data_home + "/cifar-10-batches-bin"
     if not training:
         data_dir = data_home + "/cifar-10-verify-bin"
-    data_set = ds.Cifar10Dataset(data_dir)
+    data_set = ds.Cifar10Dataset(data_dir, num_samples=num_samples)

     resize_height = 224
     resize_width = 224
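The new num_samples argument caps how many CIFAR-10 images the test reads, which makes the number of steps per epoch deterministic. With the defaults this works out as follows (quick arithmetic check in plain Python):

num_samples = 1600          # new create_dataset default
batch_size = 32
steps_per_epoch = num_samples // batch_size
print(steps_per_epoch)      # 50, matching total_step in the updated test below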
@@ -385,33 +414,25 @@ def test_pynative_resnet50():

     batch_size = 32
     num_classes = 10
+    loss_scale = 128
+    total_step = 50
     net = resnet50(batch_size, num_classes)
-    criterion = CrossEntropyLoss()
     optimizer = Momentum(learning_rate=0.01, momentum=0.9,
                          params=filter(lambda x: x.requires_grad, net.get_parameters()))
+    data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size, num_samples=total_step * batch_size)

-    net_with_criterion = WithLossCell(net, criterion)
-    net_with_criterion.set_grad()
-    train_network = GradWrap(net_with_criterion)
-    train_network.set_train()
+    # define callbacks
+    time_cb = MyTimeMonitor(data_size=data_set.get_dataset_size())
+    loss_cb = LossMonitor()
+    cb = [time_cb, loss_cb]

-    step = 0
-    max_step = 21
-    exceed_num = 0
-    data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size)
-    for element in data_set.create_dict_iterator(num_epochs=1):
-        step = step + 1
-        if step > max_step:
-            break
-        start_time = time.time()
-        input_data = element["image"]
-        input_label = element["label"]
-        loss_output = net_with_criterion(input_data, input_label)
-        grads = train_network(input_data, input_label)
-        optimizer(grads)
-        end_time = time.time()
-        cost_time = end_time - start_time
-        print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
-        if step > 1 and cost_time > 0.25:
-            exceed_num = exceed_num + 1
-    assert exceed_num < 20
+    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
+    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
+                  amp_level="O2", keep_batchnorm_fp32=False)
+
+    # train model
+    model.train(1, data_set, callbacks=cb,
+                sink_size=data_set.get_dataset_size(), dataset_sink_mode=True)
+
+    assert time_cb.good_step() > 10
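The pass criterion thus changes from "fewer than 20 of the hand-timed pynative loop steps may exceed 0.25 s" to "more than 10 of the 50 sink-mode steps must finish under 265 ms", as counted by MyTimeMonitor. A sketch of the new criterion in isolation (the timing values are hypothetical, chosen only to mimic a slow compile-heavy first step followed by steady-state steps):

# Hypothetical per-step timings in ms: one slow first step, then steady state.
step_times_ms = [900.0] + [150.0] * 49

good = sum(1 for t in step_times_ms if t < 265)   # what MyTimeMonitor.good_step() counts
assert good > 10                                  # the updated test's pass criterion
print(f"good steps: {good} / {len(step_times_ms)}")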