add new auto mix presicion testcase and check the cast num

Signed-off-by: zhushujing <zhushujing@huawei.com>
This commit is contained in:
zhushujing 2021-01-20 20:44:03 +08:00
parent 55594db01c
commit 61a43cd6eb
2 changed files with 246 additions and 1 deletions

View File

@ -13,14 +13,31 @@
# limitations under the License.
"""Test network turn on mix_precision."""
import os
import re
import pytest
import numpy as np
from mindspore.common import dtype
from mindspore import nn
from mindspore import ops
from mindspore import amp
from mindspore import Tensor
from mindspore import context
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model
from utils import FakeData
from utils import allclose_nparray
from utils import FakeDataInitMode
from utils import find_newest_validateir_file
from utils import clean_all_ir_files
def read_validateir_file(path_folder):
filename = find_newest_validateir_file(path_folder)
with open(os.path.join(filename), 'r') as f:
contend = f.read()
clean_all_ir_files(path_folder)
return contend
class Net(nn.Cell):
@ -62,7 +79,7 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_auto_mix_precision():
def test_sit_auto_mix_precision_train_o3():
input_data = np.random.randn(32, 3, 224, 224).astype(np.float64)
label_data = np.random.randn(32, 10).astype(np.float32)
# graph mode
@ -87,3 +104,74 @@ def test_auto_mix_precision():
drop_overflow_update=False))
out_pynative = train_network_pynative(Tensor(input_data), Tensor(label_data))
assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_sit_auto_mix_precision_model_o0():
input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
dataset1 = FakeData(size=32,
batch_size=32,
image_size=(3, 224, 224),
num_classes=10,
fakedata_mode=FakeDataInitMode.OnesInit)
dataset1.set_label_data_type(np.float16)
# graph mode
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True, save_graphs_path='./test_amp_o0')
net = Net(3, 10)
net.to_float(dtype.float16)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
model = Model(net, loss, opt, amp_level="O0")
model.train(1, dataset1, dataset_sink_mode=False)
contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend)
assert len(castnum) == 17
model.predict(Tensor(input_data))
contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend)
assert len(castnum) == 11
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sit_auto_mix_precision_model_o2():
input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
dataset1 = FakeData(size=32,
batch_size=32,
image_size=(3, 224, 224),
num_classes=10,
fakedata_mode=FakeDataInitMode.OnesInit)
dataset2 = FakeData(size=32,
batch_size=32,
image_size=(3, 224, 224),
num_classes=10,
fakedata_mode=FakeDataInitMode.OnesInit)
# graph mode
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True, save_graphs_path='./test_amp_o2')
net = Net(3, 10)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
model = Model(net, loss, opt, amp_level="O2")
model.train(1, dataset1, dataset_sink_mode=False)
contend = read_validateir_file('./test_amp_o2')
castnum = re.findall("Cast", contend)
assert len(castnum) == 14
out_graph = model.predict(Tensor(input_data))
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
net_pynative = Net(3, 10)
opt_pynative = nn.Momentum(params=net_pynative.trainable_params(), learning_rate=0.001, momentum=0.0009)
loss_pynative = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
model_pynative = Model(net_pynative, loss_pynative, opt_pynative, amp_level="O2")
model_pynative.train(1, dataset2, dataset_sink_mode=False)
out_pynative = model_pynative.predict(Tensor(input_data))
allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)

View File

@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" create train dataset. """
import os
import re
import numpy as np
from mindspore.communication.management import init
from mindspore.communication.management import get_rank
from mindspore.communication.management import get_group_size
from mindspore import Tensor
def _count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_me) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
if np.any(np.isnan(data_expected)):
assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
_count_unequal_element(data_expected, data_me, rtol, atol)
else:
assert True
def clean_all_ir_files(folder_path):
if os.path.exists(folder_path):
for file_name in os.listdir(folder_path):
if file_name.endswith('.ir') or file_name.endswith('.dat') or file_name.endswith('.dot'):
os.remove(os.path.join(folder_path, file_name))
def find_newest_validateir_file(folder_path):
validate_files = map(lambda f: os.path.join(folder_path, f),
filter(lambda f: re.match(r'\d+_validate_\d+.ir', f), os.listdir(folder_path)))
return max(validate_files, key=os.path.getctime)
class FakeDataInitMode:
RandomInit = 0
OnesInit = 1
UniqueInit = 2
ZerosInit = 3
class FakeData:
def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
num_classes=10, random_offset=0, use_parallel=False,
fakedata_mode=FakeDataInitMode.RandomInit):
self.size = size
self.rank_batch_size = batch_size
self.total_batch_size = self.rank_batch_size
self.random_offset = random_offset
self.image_size = image_size
self.num_classes = num_classes
self.rank_size = 1
self.rank_id = 0
self.batch_index = 0
self.image_data_type = np.float32
self.label_data_type = np.float32
self.is_onehot = True
self.fakedata_mode = fakedata_mode
if use_parallel is True:
init()
self.rank_size = get_group_size()
self.rank_id = get_rank()
self.total_batch_size = self.rank_batch_size * self.rank_size
assert (self.size % self.total_batch_size) == 0
self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size
def get_dataset_size(self):
return int(self.size / self.total_batch_size)
def get_repeat_count(self):
return 1
def set_image_data_type(self, data_type):
self.image_data_type = data_type
def set_label_data_type(self, data_type):
self.label_data_type = data_type
def set_label_onehot(self, is_onehot=True):
self.is_onehot = is_onehot
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
_ = num_epochs
return self
def __getitem__(self, batch_index):
if batch_index * self.total_batch_size >= len(self):
raise IndexError("{} index out of range".format(self.__class__.__name__))
rng_state = np.random.get_state()
np.random.seed(batch_index + self.random_offset)
if self.fakedata_mode == FakeDataInitMode.OnesInit:
img = np.ones(self.total_batch_data_size)
elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
img = np.zeros(self.total_batch_data_size)
elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
total_size = 1
for i in self.total_batch_data_size:
total_size = total_size * i
img = np.reshape(np.arange(total_size) * 0.0001, self.total_batch_data_size)
else:
img = np.random.randn(*self.total_batch_data_size)
target = np.random.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
np.random.set_state(rng_state)
img = img[self.rank_id]
target = target[self.rank_id]
img_ret = img.astype(self.image_data_type)
target_ret = target.astype(self.label_data_type)
if self.is_onehot:
target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
target_onehot[np.arange(self.rank_batch_size), target] = 1
target_ret = target_onehot.astype(self.label_data_type)
return Tensor(img_ret), Tensor(target_ret)
def __len__(self):
return self.size
def __iter__(self):
self.batch_index = 0
return self
def reset(self):
self.batch_index = 0
def __next__(self):
if self.batch_index * self.total_batch_size < len(self):
data = self[self.batch_index]
self.batch_index += 1
return data
raise StopIteration