From 79f86d3fe2819fc277fa0d8e7a7fc689a29a0cf1 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Sat, 27 Nov 2021 15:01:09 +0800 Subject: [PATCH] Add test case for compile cache --- tests/st/frontend_compile_cache/run_lenet.py | 76 ++++++++ .../run_network_with_control_flow.py | 55 ++++++ .../run_network_with_weights.py | 51 ++++++ .../test_compile_cache.py | 162 ++++++++++++++++++ 4 files changed, 344 insertions(+) create mode 100644 tests/st/frontend_compile_cache/run_lenet.py create mode 100644 tests/st/frontend_compile_cache/run_network_with_control_flow.py create mode 100644 tests/st/frontend_compile_cache/run_network_with_weights.py create mode 100644 tests/st/frontend_compile_cache/test_compile_cache.py diff --git a/tests/st/frontend_compile_cache/run_lenet.py b/tests/st/frontend_compile_cache/run_lenet.py new file mode 100644 index 00000000000..395bbab3a4f --- /dev/null +++ b/tests/st/frontend_compile_cache/run_lenet.py @@ -0,0 +1,76 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import sys +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum +from mindspore.ops import operations as P + + +class LeNet(nn.Cell): + def __init__(self): + super(LeNet, self).__init__() + self.relu = P.ReLU() + self.batch_size = 32 + + self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + self.fc1 = nn.Dense(400, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + + def construct(self, input_x): + output = self.conv1(input_x) + output = self.relu(output) + output = self.pool(output) + output = self.conv2(output) + output = self.relu(output) + output = self.pool(output) + output = self.reshape(output, (self.batch_size, -1)) + output = self.fc1(output) + output = self.relu(output) + output = self.fc2(output) + output = self.relu(output) + output = self.fc3(output) + return output + + +def train(net, data, label): + learning_rate = 0.01 + momentum = 0.9 + + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + res = train_network(data, label) + print("{", res, "}") + print("{", res.asnumpy().shape, "}") + + +if __name__ == "__main__": + context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1]) + input_data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) + input_label = Tensor(np.ones([32]).astype(np.int32)) + lenet = LeNet() + train(lenet, input_data, input_label) + context.set_context(enable_compile_cache=False) diff --git a/tests/st/frontend_compile_cache/run_network_with_control_flow.py b/tests/st/frontend_compile_cache/run_network_with_control_flow.py new file mode 100644 index 00000000000..44bc8b1f5fc --- /dev/null +++ b/tests/st/frontend_compile_cache/run_network_with_control_flow.py @@ -0,0 +1,55 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import sys +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore import dtype as mstype +from mindspore.ops import operations as P + + +class NetWithControlFlow(nn.Cell): + def __init__(self): + super(NetWithControlFlow, self).__init__() + self.mul = P.Mul() + self.add = P.Add() + param_a = np.full((1,), 5, dtype=np.float32) + self.param_a = Parameter(Tensor(param_a), name='a') + param_b = np.full((1,), 4, dtype=np.float32) + self.param_b = Parameter(Tensor(param_b), name='b') + + def construct(self, x): + if self.param_a > self.param_b: + x = self.mul(x, 2) + for _ in range(0, 5): + x = self.add(x, x) + self.param_b += 1 + return x + + +def run_net_with_control_flow(): + x = Tensor([10], mstype.int32) + net = NetWithControlFlow() + output = net(x) + print("{", output, "}") + print("{", output.asnumpy().shape, "}") + + +if __name__ == "__main__": + context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1]) + run_net_with_control_flow() + context.set_context(enable_compile_cache=False) diff --git a/tests/st/frontend_compile_cache/run_network_with_weights.py b/tests/st/frontend_compile_cache/run_network_with_weights.py new file mode 100644 index 00000000000..b178b988a78 --- /dev/null +++ b/tests/st/frontend_compile_cache/run_network_with_weights.py @@ -0,0 +1,51 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import sys +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore import dtype as mstype +from mindspore.ops import operations as P + + +class NetWithWeights(nn.Cell): + def __init__(self): + super(NetWithWeights, self).__init__() + self.matmul = P.MatMul() + self.a = Parameter(Tensor(np.array([2.0], np.float32)), name='a') + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, x, y): + x = x * self.z + y = y * self.a + out = self.matmul(x, y) + return out + + +def run_simple_net(): + x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32) + y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32) + net = NetWithWeights() + output = net(x, y) + print("{", output, "}") + print("{", output.asnumpy().shape, "}") + + +if __name__ == "__main__": + context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1]) + run_simple_net() + context.set_context(enable_compile_cache=False) diff --git a/tests/st/frontend_compile_cache/test_compile_cache.py b/tests/st/frontend_compile_cache/test_compile_cache.py new file mode 100644 index 00000000000..9dd035943da --- /dev/null +++ b/tests/st/frontend_compile_cache/test_compile_cache.py @@ -0,0 +1,162 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import re +import shutil +import subprocess +import pytest +import numpy as np + +match_output = re.compile(r'[{](.*?)[}]', re.S) +match_num = re.compile(r'\d+\.?\d*', re.S) + + +def run_twice_with_same_network(file_name, cache_path, log_file_name_first, log_file_name_second): + # Clear compile cache folder and log files + if os.path.exists(cache_path): + shutil.rmtree(cache_path) + if os.path.exists(log_file_name_first): + os.remove(log_file_name_first) + if os.path.exists(log_file_name_second): + os.remove(log_file_name_second) + assert not os.path.exists(cache_path) + assert not os.path.exists(log_file_name_first) + assert not os.path.exists(log_file_name_second) + + # First run without compile cache + cmd_first = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_first + " 2>&1" + subprocess.check_output(cmd_first, shell=True) + assert os.path.exists(log_file_name_first) + assert os.path.exists(cache_path) + with open(log_file_name_first, "r") as f_first: + data_first = f_first.read() + assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_first + + # Take out the result of the first run + match_output_first = re.findall(match_output, data_first) + assert len(match_output_first) == 2 + nums_first = re.findall(match_num, match_output_first[0]) + array_first = np.array([float(x) for x in nums_first]) + shape_first = re.findall(match_num, match_output_first[1]) + array_shape_first = np.array([int(x) for x in shape_first]) + + # Second run with compile cache + cmd_second = cmd_first = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_second +\ + " 2>&1" + subprocess.check_output(cmd_second, shell=True) + assert os.path.exists(log_file_name_second) + with open(log_file_name_second, "r") as f_second: + data_second = f_second.read() + assert "Use the compilation cache and execute the backend actions only. Be aware of correctness risks." in \ + data_second + + # Take out the result of the second run + match_output_second = re.findall(match_output, data_second) + assert len(match_output_second) == 2 + nums_second = re.findall(match_num, match_output_second[0]) + array_second = np.array([float(x) for x in nums_second]) + shape_second = re.findall(match_num, match_output_second[1]) + array_shape_second = np.array([int(x) for x in shape_second]) + + assert np.allclose(array_first, array_second, 0.0001, 0.0001) + assert (array_shape_first == array_shape_second).all() + + # Clean files + os.remove(log_file_name_first) + os.remove(log_file_name_second) + shutil.rmtree(cache_path) + + +def run_twice_with_different_networks(file_name_first, file_name_second, cache_path, log_file_name_first, + log_file_name_second): + # Clear compile cache folder + if os.path.exists(cache_path): + shutil.rmtree(cache_path) + assert not os.path.exists(cache_path) + + # First run without compile cache + cmd_first = f"GLOG_v=2 python " + file_name_first + " '" + cache_path + "' > " + log_file_name_first + " 2>&1" + subprocess.check_output(cmd_first, shell=True) + assert os.path.exists(log_file_name_first) + assert os.path.exists(cache_path) + with open(log_file_name_first, "r") as f_first: + data_first = f_first.read() + assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_first + + # Second run with compile cache + cmd_second = f"GLOG_v=2 python " + file_name_second + " '" + cache_path + "' > " + log_file_name_second + " 2>&1" + subprocess.check_output(cmd_second, shell=True) + assert os.path.exists(log_file_name_second) + with open(log_file_name_second, "r") as f_second: + data_second = f_second.read() + assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_second + + # Clean log files + os.remove(log_file_name_first) + os.remove(log_file_name_second) + shutil.rmtree(cache_path) + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_compile_cache_load_weights(): + """ + Feature: Compile cache. + Description: Test whether the compile cache can load the value of parameters successfully. + Expectation: success. + """ + run_twice_with_same_network("run_network_with_weights.py", "./weight", "weight_first.txt", "weight_second.txt") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_compile_cache_lenet(): + """ + Feature: Compile cache. + Description: Test whether the regular compile cache function can run successfully. + Expectation: success. + """ + run_twice_with_same_network("run_lenet.py", "./lenet", "lenet_first.txt", "lenet_second.txt") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_compile_cache_net_with_control_flow(): + """ + Feature: Compile cache. + Description: Test whether the compile cache can load ref type parameter correctly. + Expectation: success. + """ + run_twice_with_same_network("run_network_with_control_flow.py", "./control_flow", "control_net_first.txt", + "control_net_second.txt") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_compile_cache_auto_detect(): + """ + Feature: Compile cache. + Description: Test whether the compile cache auto-detection function can run successfully. + Expectation: success. + """ + run_twice_with_different_networks("run_lenet.py", "run_network_with_weights.py", "./lenet_auto_detect", + "auto_detect_first.txt", "auto_detect_second.txt")