forked from mindspore-Ecosystem/mindspore
!28406 Add compile cache support for ms_function
Merge pull request !28406 from YuJianfeng/compile_cache1
This commit is contained in:
commit
449f52012e
|
@ -746,9 +746,9 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
|
|||
}
|
||||
|
||||
void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase) {
|
||||
// The compilation cache only support for training currently.
|
||||
// The compilation cache only support for training cell or ms_function currently.
|
||||
// If enable compilation cache, it will get a non-empty dependent files list from python.
|
||||
if (!IsPhaseTrain(phase) || compile_cache_dep_files_.empty()) {
|
||||
if (compile_cache_dep_files_.empty()) {
|
||||
return;
|
||||
}
|
||||
#ifdef ENABLE_PROFILE
|
||||
|
|
|
@ -215,6 +215,14 @@ class _MindsporeFunctionExecutor:
|
|||
init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
|
||||
_exec_init_graph(self.obj, init_phase)
|
||||
|
||||
def _set_compile_cache_dep_files(self):
|
||||
# If enable compile cache, get the dependency files list
|
||||
enable_compile_cache = context.get_context("enable_compile_cache")
|
||||
if enable_compile_cache is None:
|
||||
enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
|
||||
if enable_compile_cache is True or enable_compile_cache == "1":
|
||||
self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
|
||||
|
||||
def compile(self, args_list, method_name):
|
||||
"""Returns pipeline for the given args."""
|
||||
# Verify the signature for both function and method
|
||||
|
@ -255,6 +263,9 @@ class _MindsporeFunctionExecutor:
|
|||
if phase in ms_compile_cache:
|
||||
return phase
|
||||
|
||||
# If enable compile cache, get the dependency files list and set to graph executor.
|
||||
self._set_compile_cache_dep_files()
|
||||
|
||||
if self.obj is None:
|
||||
is_compile = self._graph_executor.compile(self.fn, args_list, phase, True)
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
|
||||
|
||||
@ms_function
|
||||
def func(input_x, input_y):
|
||||
output = input_x + input_x * input_y
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(enable_compile_cache=True, compile_cache_path=sys.argv[1])
|
||||
x = Tensor(np.array([1]).astype(np.float32))
|
||||
y = Tensor(np.array([2]).astype(np.float32))
|
||||
res = func(x, y)
|
||||
print("{", res, "}")
|
||||
print("{", res.asnumpy().shape, "}")
|
||||
context.set_context(enable_compile_cache=False)
|
|
@ -277,3 +277,17 @@ def test_compile_cache_lenet_ps():
|
|||
Expectation: success.
|
||||
"""
|
||||
run_lenet_ps_twice("run_lenet_ps.py", "./lenet_ps", "lenet_ps_first.txt", "lenet_ps_second.txt")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_compile_cache_ms_function():
|
||||
"""
|
||||
Feature: Compile cache.
|
||||
Description: Test whether the compile cache function can run successfully in the compilation of ms_function.
|
||||
Expectation: success.
|
||||
"""
|
||||
run_twice_with_same_network("run_lenet_ms_function.py", "./lenet_ms_function", "lenet_ms_function_first.txt",
|
||||
"lenet_ms_function_second.txt")
|
||||
|
|
Loading…
Reference in New Issue