!26851 Add compile cache st test cases

Merge pull request !26851 from LiangZhibo/master
This commit is contained in:
i-robot 2021-12-07 00:50:32 +00:00 committed by Gitee
commit 18163f07aa
4 changed files with 344 additions and 0 deletions

View File

@ -0,0 +1,76 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import sys
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
def train(net, data, label):
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
print("{", res, "}")
print("{", res.asnumpy().shape, "}")
if __name__ == "__main__":
context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1])
input_data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
input_label = Tensor(np.ones([32]).astype(np.int32))
lenet = LeNet()
train(lenet, input_data, input_label)
context.set_context(enable_compile_cache=False)

View File

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import sys
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.ops import operations as P
class NetWithControlFlow(nn.Cell):
def __init__(self):
super(NetWithControlFlow, self).__init__()
self.mul = P.Mul()
self.add = P.Add()
param_a = np.full((1,), 5, dtype=np.float32)
self.param_a = Parameter(Tensor(param_a), name='a')
param_b = np.full((1,), 4, dtype=np.float32)
self.param_b = Parameter(Tensor(param_b), name='b')
def construct(self, x):
if self.param_a > self.param_b:
x = self.mul(x, 2)
for _ in range(0, 5):
x = self.add(x, x)
self.param_b += 1
return x
def run_net_with_control_flow():
x = Tensor([10], mstype.int32)
net = NetWithControlFlow()
output = net(x)
print("{", output, "}")
print("{", output.asnumpy().shape, "}")
if __name__ == "__main__":
context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1])
run_net_with_control_flow()
context.set_context(enable_compile_cache=False)

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import sys
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.ops import operations as P
class NetWithWeights(nn.Cell):
def __init__(self):
super(NetWithWeights, self).__init__()
self.matmul = P.MatMul()
self.a = Parameter(Tensor(np.array([2.0], np.float32)), name='a')
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
y = y * self.a
out = self.matmul(x, y)
return out
def run_simple_net():
x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
net = NetWithWeights()
output = net(x, y)
print("{", output, "}")
print("{", output.asnumpy().shape, "}")
if __name__ == "__main__":
context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1])
run_simple_net()
context.set_context(enable_compile_cache=False)

View File

@ -0,0 +1,162 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import re
import shutil
import subprocess
import pytest
import numpy as np
match_output = re.compile(r'[{](.*?)[}]', re.S)
match_num = re.compile(r'\d+\.?\d*', re.S)
def run_twice_with_same_network(file_name, cache_path, log_file_name_first, log_file_name_second):
# Clear compile cache folder and log files
if os.path.exists(cache_path):
shutil.rmtree(cache_path)
if os.path.exists(log_file_name_first):
os.remove(log_file_name_first)
if os.path.exists(log_file_name_second):
os.remove(log_file_name_second)
assert not os.path.exists(cache_path)
assert not os.path.exists(log_file_name_first)
assert not os.path.exists(log_file_name_second)
# First run without compile cache
cmd_first = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_first + " 2>&1"
subprocess.check_output(cmd_first, shell=True)
assert os.path.exists(log_file_name_first)
assert os.path.exists(cache_path)
with open(log_file_name_first, "r") as f_first:
data_first = f_first.read()
assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_first
# Take out the result of the first run
match_output_first = re.findall(match_output, data_first)
assert len(match_output_first) == 2
nums_first = re.findall(match_num, match_output_first[0])
array_first = np.array([float(x) for x in nums_first])
shape_first = re.findall(match_num, match_output_first[1])
array_shape_first = np.array([int(x) for x in shape_first])
# Second run with compile cache
cmd_second = cmd_first = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_second +\
" 2>&1"
subprocess.check_output(cmd_second, shell=True)
assert os.path.exists(log_file_name_second)
with open(log_file_name_second, "r") as f_second:
data_second = f_second.read()
assert "Use the compilation cache and execute the backend actions only. Be aware of correctness risks." in \
data_second
# Take out the result of the second run
match_output_second = re.findall(match_output, data_second)
assert len(match_output_second) == 2
nums_second = re.findall(match_num, match_output_second[0])
array_second = np.array([float(x) for x in nums_second])
shape_second = re.findall(match_num, match_output_second[1])
array_shape_second = np.array([int(x) for x in shape_second])
assert np.allclose(array_first, array_second, 0.0001, 0.0001)
assert (array_shape_first == array_shape_second).all()
# Clean files
os.remove(log_file_name_first)
os.remove(log_file_name_second)
shutil.rmtree(cache_path)
def run_twice_with_different_networks(file_name_first, file_name_second, cache_path, log_file_name_first,
log_file_name_second):
# Clear compile cache folder
if os.path.exists(cache_path):
shutil.rmtree(cache_path)
assert not os.path.exists(cache_path)
# First run without compile cache
cmd_first = f"GLOG_v=2 python " + file_name_first + " '" + cache_path + "' > " + log_file_name_first + " 2>&1"
subprocess.check_output(cmd_first, shell=True)
assert os.path.exists(log_file_name_first)
assert os.path.exists(cache_path)
with open(log_file_name_first, "r") as f_first:
data_first = f_first.read()
assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_first
# Second run with compile cache
cmd_second = f"GLOG_v=2 python " + file_name_second + " '" + cache_path + "' > " + log_file_name_second + " 2>&1"
subprocess.check_output(cmd_second, shell=True)
assert os.path.exists(log_file_name_second)
with open(log_file_name_second, "r") as f_second:
data_second = f_second.read()
assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_second
# Clean log files
os.remove(log_file_name_first)
os.remove(log_file_name_second)
shutil.rmtree(cache_path)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_compile_cache_load_weights():
"""
Feature: Compile cache.
Description: Test whether the compile cache can load the value of parameters successfully.
Expectation: success.
"""
run_twice_with_same_network("run_network_with_weights.py", "./weight", "weight_first.txt", "weight_second.txt")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_compile_cache_lenet():
"""
Feature: Compile cache.
Description: Test whether the regular compile cache function can run successfully.
Expectation: success.
"""
run_twice_with_same_network("run_lenet.py", "./lenet", "lenet_first.txt", "lenet_second.txt")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_compile_cache_net_with_control_flow():
"""
Feature: Compile cache.
Description: Test whether the compile cache can load ref type parameter correctly.
Expectation: success.
"""
run_twice_with_same_network("run_network_with_control_flow.py", "./control_flow", "control_net_first.txt",
"control_net_second.txt")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_compile_cache_auto_detect():
"""
Feature: Compile cache.
Description: Test whether the compile cache auto-detection function can run successfully.
Expectation: success.
"""
run_twice_with_different_networks("run_lenet.py", "run_network_with_weights.py", "./lenet_auto_detect",
"auto_detect_first.txt", "auto_detect_second.txt")