forked from mindspore-Ecosystem/mindspore
!26851 Add compile cache st test cases
Merge pull request !26851 from LiangZhibo/master
This commit is contained in:
commit
18163f07aa
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class LeNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(LeNet, self).__init__()
|
||||
self.relu = P.ReLU()
|
||||
self.batch_size = 32
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.reshape = P.Reshape()
|
||||
self.fc1 = nn.Dense(400, 120)
|
||||
self.fc2 = nn.Dense(120, 84)
|
||||
self.fc3 = nn.Dense(84, 10)
|
||||
|
||||
def construct(self, input_x):
|
||||
output = self.conv1(input_x)
|
||||
output = self.relu(output)
|
||||
output = self.pool(output)
|
||||
output = self.conv2(output)
|
||||
output = self.relu(output)
|
||||
output = self.pool(output)
|
||||
output = self.reshape(output, (self.batch_size, -1))
|
||||
output = self.fc1(output)
|
||||
output = self.relu(output)
|
||||
output = self.fc2(output)
|
||||
output = self.relu(output)
|
||||
output = self.fc3(output)
|
||||
return output
|
||||
|
||||
|
||||
def train(net, data, label):
|
||||
learning_rate = 0.01
|
||||
momentum = 0.9
|
||||
|
||||
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
net_with_criterion = WithLossCell(net, criterion)
|
||||
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
|
||||
train_network.set_train()
|
||||
res = train_network(data, label)
|
||||
print("{", res, "}")
|
||||
print("{", res.asnumpy().shape, "}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1])
|
||||
input_data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
|
||||
input_label = Tensor(np.ones([32]).astype(np.int32))
|
||||
lenet = LeNet()
|
||||
train(lenet, input_data, input_label)
|
||||
context.set_context(enable_compile_cache=False)
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NetWithControlFlow(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetWithControlFlow, self).__init__()
|
||||
self.mul = P.Mul()
|
||||
self.add = P.Add()
|
||||
param_a = np.full((1,), 5, dtype=np.float32)
|
||||
self.param_a = Parameter(Tensor(param_a), name='a')
|
||||
param_b = np.full((1,), 4, dtype=np.float32)
|
||||
self.param_b = Parameter(Tensor(param_b), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
if self.param_a > self.param_b:
|
||||
x = self.mul(x, 2)
|
||||
for _ in range(0, 5):
|
||||
x = self.add(x, x)
|
||||
self.param_b += 1
|
||||
return x
|
||||
|
||||
|
||||
def run_net_with_control_flow():
|
||||
x = Tensor([10], mstype.int32)
|
||||
net = NetWithControlFlow()
|
||||
output = net(x)
|
||||
print("{", output, "}")
|
||||
print("{", output.asnumpy().shape, "}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1])
|
||||
run_net_with_control_flow()
|
||||
context.set_context(enable_compile_cache=False)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NetWithWeights(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetWithWeights, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
self.a = Parameter(Tensor(np.array([2.0], np.float32)), name='a')
|
||||
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||
|
||||
def construct(self, x, y):
|
||||
x = x * self.z
|
||||
y = y * self.a
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def run_simple_net():
|
||||
x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
|
||||
y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
|
||||
net = NetWithWeights()
|
||||
output = net(x, y)
|
||||
print("{", output, "}")
|
||||
print("{", output.asnumpy().shape, "}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1])
|
||||
run_simple_net()
|
||||
context.set_context(enable_compile_cache=False)
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
match_output = re.compile(r'[{](.*?)[}]', re.S)
|
||||
match_num = re.compile(r'\d+\.?\d*', re.S)
|
||||
|
||||
|
||||
def run_twice_with_same_network(file_name, cache_path, log_file_name_first, log_file_name_second):
|
||||
# Clear compile cache folder and log files
|
||||
if os.path.exists(cache_path):
|
||||
shutil.rmtree(cache_path)
|
||||
if os.path.exists(log_file_name_first):
|
||||
os.remove(log_file_name_first)
|
||||
if os.path.exists(log_file_name_second):
|
||||
os.remove(log_file_name_second)
|
||||
assert not os.path.exists(cache_path)
|
||||
assert not os.path.exists(log_file_name_first)
|
||||
assert not os.path.exists(log_file_name_second)
|
||||
|
||||
# First run without compile cache
|
||||
cmd_first = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_first + " 2>&1"
|
||||
subprocess.check_output(cmd_first, shell=True)
|
||||
assert os.path.exists(log_file_name_first)
|
||||
assert os.path.exists(cache_path)
|
||||
with open(log_file_name_first, "r") as f_first:
|
||||
data_first = f_first.read()
|
||||
assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_first
|
||||
|
||||
# Take out the result of the first run
|
||||
match_output_first = re.findall(match_output, data_first)
|
||||
assert len(match_output_first) == 2
|
||||
nums_first = re.findall(match_num, match_output_first[0])
|
||||
array_first = np.array([float(x) for x in nums_first])
|
||||
shape_first = re.findall(match_num, match_output_first[1])
|
||||
array_shape_first = np.array([int(x) for x in shape_first])
|
||||
|
||||
# Second run with compile cache
|
||||
cmd_second = cmd_first = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_second +\
|
||||
" 2>&1"
|
||||
subprocess.check_output(cmd_second, shell=True)
|
||||
assert os.path.exists(log_file_name_second)
|
||||
with open(log_file_name_second, "r") as f_second:
|
||||
data_second = f_second.read()
|
||||
assert "Use the compilation cache and execute the backend actions only. Be aware of correctness risks." in \
|
||||
data_second
|
||||
|
||||
# Take out the result of the second run
|
||||
match_output_second = re.findall(match_output, data_second)
|
||||
assert len(match_output_second) == 2
|
||||
nums_second = re.findall(match_num, match_output_second[0])
|
||||
array_second = np.array([float(x) for x in nums_second])
|
||||
shape_second = re.findall(match_num, match_output_second[1])
|
||||
array_shape_second = np.array([int(x) for x in shape_second])
|
||||
|
||||
assert np.allclose(array_first, array_second, 0.0001, 0.0001)
|
||||
assert (array_shape_first == array_shape_second).all()
|
||||
|
||||
# Clean files
|
||||
os.remove(log_file_name_first)
|
||||
os.remove(log_file_name_second)
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
|
||||
def run_twice_with_different_networks(file_name_first, file_name_second, cache_path, log_file_name_first,
|
||||
log_file_name_second):
|
||||
# Clear compile cache folder
|
||||
if os.path.exists(cache_path):
|
||||
shutil.rmtree(cache_path)
|
||||
assert not os.path.exists(cache_path)
|
||||
|
||||
# First run without compile cache
|
||||
cmd_first = f"GLOG_v=2 python " + file_name_first + " '" + cache_path + "' > " + log_file_name_first + " 2>&1"
|
||||
subprocess.check_output(cmd_first, shell=True)
|
||||
assert os.path.exists(log_file_name_first)
|
||||
assert os.path.exists(cache_path)
|
||||
with open(log_file_name_first, "r") as f_first:
|
||||
data_first = f_first.read()
|
||||
assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_first
|
||||
|
||||
# Second run with compile cache
|
||||
cmd_second = f"GLOG_v=2 python " + file_name_second + " '" + cache_path + "' > " + log_file_name_second + " 2>&1"
|
||||
subprocess.check_output(cmd_second, shell=True)
|
||||
assert os.path.exists(log_file_name_second)
|
||||
with open(log_file_name_second, "r") as f_second:
|
||||
data_second = f_second.read()
|
||||
assert "Check the consistency of dependency files hash failed. Execute all the compilation actions." in data_second
|
||||
|
||||
# Clean log files
|
||||
os.remove(log_file_name_first)
|
||||
os.remove(log_file_name_second)
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_compile_cache_load_weights():
|
||||
"""
|
||||
Feature: Compile cache.
|
||||
Description: Test whether the compile cache can load the value of parameters successfully.
|
||||
Expectation: success.
|
||||
"""
|
||||
run_twice_with_same_network("run_network_with_weights.py", "./weight", "weight_first.txt", "weight_second.txt")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_compile_cache_lenet():
|
||||
"""
|
||||
Feature: Compile cache.
|
||||
Description: Test whether the regular compile cache function can run successfully.
|
||||
Expectation: success.
|
||||
"""
|
||||
run_twice_with_same_network("run_lenet.py", "./lenet", "lenet_first.txt", "lenet_second.txt")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_compile_cache_net_with_control_flow():
|
||||
"""
|
||||
Feature: Compile cache.
|
||||
Description: Test whether the compile cache can load ref type parameter correctly.
|
||||
Expectation: success.
|
||||
"""
|
||||
run_twice_with_same_network("run_network_with_control_flow.py", "./control_flow", "control_net_first.txt",
|
||||
"control_net_second.txt")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_compile_cache_auto_detect():
|
||||
"""
|
||||
Feature: Compile cache.
|
||||
Description: Test whether the compile cache auto-detection function can run successfully.
|
||||
Expectation: success.
|
||||
"""
|
||||
run_twice_with_different_networks("run_lenet.py", "run_network_with_weights.py", "./lenet_auto_detect",
|
||||
"auto_detect_first.txt", "auto_detect_second.txt")
|
Loading…
Reference in New Issue