GE support PyNative ms_function.
Update Chinese docs. Opt
This commit is contained in:
parent
666428a765
commit
bcbf41238c
|
@ -9,11 +9,12 @@ mindspore.JitConfig
|
|||
- 这是一个实验性接口,后续可能删除或修改。
|
||||
|
||||
参数:
|
||||
- **jit_level** (str) - 设置编译优化的级别,支持["O0", "O1", "O2"]。默认值:"O1"。
|
||||
- **jit_level** (str) - 设置编译优化的级别,支持["O0", "O1", "O2", "O3"]。默认值:"O1"。
|
||||
|
||||
- "O0": 基础优化。
|
||||
- "O1": 手动优化。
|
||||
- "O2": 手动优化与图算优化结合。
|
||||
- "O3": 性能优化,无法保证泛化性。
|
||||
|
||||
- **task_sink** (bool) - 数据是否直接下沉至处理器进行处理。默认值:True。
|
||||
- **kwargs** (dict) - 关键字参数字典。
|
||||
|
|
|
@ -311,9 +311,11 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
|
|||
func_graph_to_kernel_graph_ids_.clear();
|
||||
control_nodes_.clear();
|
||||
|
||||
auto jit_level = common::AnfAlgo::GetJitLevel(func_graph);
|
||||
const auto &device_context =
|
||||
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
|
||||
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_}, jit_level);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
device_context->Initialize();
|
||||
bool all_support = device_context->PartitionGraph(func_graph);
|
||||
if (all_support) {
|
||||
auto run_mode = device_context->GetRunMode(func_graph);
|
||||
|
@ -331,7 +333,8 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
|
|||
// Construct the graph compiler info.
|
||||
auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
|
||||
MS_EXCEPTION_IF_NULL(graph_compiler_info);
|
||||
if (ms_execution_mode_ == kGraphMode &&
|
||||
if ((ms_execution_mode_ == kGraphMode ||
|
||||
(ms_execution_mode_ == kPynativeMode && jit_level == "O3" && context_ptr->backend_policy() == "ge")) &&
|
||||
((!graph_compiler_info->graphs_.empty()) || graph_compiler_info->control_nodes_.size() > 1)) {
|
||||
// Transform graph to actor DAG, and schedule the actor DAG.
|
||||
ParseControlNodes(*graph_compiler_info);
|
||||
|
|
|
@ -286,6 +286,9 @@ class COMMON_EXPORT AnfAlgo {
|
|||
|
||||
static std::string GetTensorValueString(const tensor::TensorPtr &tensor);
|
||||
static abstract::AbstractBasePtr GetNodeAbstractByIndex(const AnfNodePtr &node, size_t index);
|
||||
|
||||
// Get jit level from func_graph
|
||||
static std::string GetJitLevel(const FuncGraphPtr &func_graph);
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -681,6 +681,7 @@ constexpr auto kAttrNeedGradFlagOfInputs = "need_grad_flag_of_inputs";
|
|||
constexpr auto kAttrIsCNodeNeedGrad = "is_cnode_need_grad";
|
||||
constexpr auto kAttrJitLevel = "jit_level";
|
||||
constexpr auto kAttrJitLevelO2 = "O2";
|
||||
constexpr auto kAttrJitLevelO3 = "O3";
|
||||
constexpr auto kAttrCellJitConfigDict = "_jit_config_dict";
|
||||
constexpr auto kAttrBinaryOutput = "binary_output";
|
||||
constexpr auto kAttrMinLength = "minlength";
|
||||
|
|
|
@ -392,8 +392,9 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
|
|||
}
|
||||
|
||||
// todo: waiting for GraphExecutor
|
||||
auto jit_level = common::AnfAlgo::GetJitLevel(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
if (MsContext::GetInstance()->backend_policy() == "ge") {
|
||||
if (MsContext::GetInstance()->backend_policy() == "ge" && (jit_level == "O3" || jit_level == "")) {
|
||||
auto manager = MakeManager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (const auto &graph : all_graphs) {
|
||||
|
|
|
@ -226,13 +226,14 @@ void DeviceContextManager::ClearDeviceContexts() {
|
|||
device_contexts_.clear();
|
||||
}
|
||||
|
||||
DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key) {
|
||||
DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key,
|
||||
string jit_level /* ="" */) {
|
||||
std::string device_context_key_str = device_context_key.ToString();
|
||||
std::string name = device_context_key.device_name_;
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->backend_policy() == "ge") {
|
||||
if (ms_context->backend_policy() == "ge" && (jit_level == kAttrJitLevelO3 || jit_level == "")) {
|
||||
name = "GE";
|
||||
device_context_key_str = "GE_0";
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ class BACKEND_EXPORT DeviceContextManager {
|
|||
public:
|
||||
static DeviceContextManager &GetInstance();
|
||||
void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
|
||||
DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key);
|
||||
DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key, string jit_level = "");
|
||||
void UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key);
|
||||
void ClearDeviceContexts();
|
||||
void WaitTaskFinishOnDevice() const;
|
||||
|
|
|
@ -357,7 +357,7 @@ bool GraphAdapter::PyNativeEnableTaskSink(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
auto jit_level_value = func_graph->get_attr(kAttrJitLevel);
|
||||
auto jit_level = GetValue<std::string>(jit_level_value);
|
||||
if (jit_level != kAttrJitLevelO2) {
|
||||
if (jit_level != kAttrJitLevelO2 && jit_level != kAttrJitLevelO3) {
|
||||
MS_LOG(INFO) << "jit_level is " << jit_level << ", task sink is disabled";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -81,6 +81,7 @@ OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run
|
|||
|
||||
graph->set_run_mode(device::RunMode::kKernelMode);
|
||||
graph->set_is_from_single_op(true);
|
||||
MS_EXCEPTION_IF_NULL(device_context->kernel_executor_);
|
||||
// session_ is SessionBasic, AscendUnifyMindIR has not been executed.
|
||||
auto deprecated_kernel_executor =
|
||||
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
|
||||
|
|
|
@ -1750,5 +1750,16 @@ abstract::AbstractBasePtr AnfAlgo::GetNodeAbstractByIndex(const AnfNodePtr &node
|
|||
}
|
||||
return elements[index];
|
||||
}
|
||||
|
||||
std::string AnfAlgo::GetJitLevel(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (!func_graph->has_attr(kAttrJitLevel)) {
|
||||
MS_LOG(INFO) << "The func_graph:" << func_graph->ToString() << " has no jit_level attr, return default: None.";
|
||||
return "";
|
||||
}
|
||||
auto jit_level_value = func_graph->get_attr(kAttrJitLevel);
|
||||
auto jit_level = GetValue<std::string>(jit_level_value);
|
||||
return jit_level;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -300,6 +300,15 @@ class _MindsporeFunctionExecutor:
|
|||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
_pynative_executor.set_graph_phase(phase)
|
||||
output = _pynative_executor.grad_ms_function(output, *new_inputs)
|
||||
enable_ge = os.getenv("MS_ENABLE_GE") == "1"
|
||||
if enable_ge and self.jit_config_dict is None:
|
||||
raise RuntimeError("GE and jit_level=O3 should be used together, but jit_config is None.")
|
||||
if self.jit_config_dict:
|
||||
enable_jit_level_o3 = self.jit_config_dict.get('jit_level') == "O3"
|
||||
if (enable_ge and not enable_jit_level_o3) or (not enable_ge and enable_jit_level_o3):
|
||||
raise RuntimeError("GE and jit_level=O3 should be used together, but "
|
||||
"got MS_ENABLE_GE={}, jit_level={}".format(
|
||||
os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
|
||||
|
||||
return output
|
||||
|
||||
|
|
|
@ -24,11 +24,12 @@ class JitConfig:
|
|||
|
||||
Args:
|
||||
jit_level (str): Option for argument `level` for Optimization of lift graph.
|
||||
Supports ["O0", "O1", "O2"]. Default: "O1".
|
||||
Supports ["O0", "O1", "O2", "O3"]. Default: "O1".
|
||||
|
||||
- "O0": Basic optimization.
|
||||
- "O1": Manual optimization.
|
||||
- "O2": Manual optimization and graph computation fusion.
|
||||
- "O3": Performance optimization, no generalization guaranteed.
|
||||
|
||||
task_sink (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
**kwargs (dict): A dictionary of keyword arguments that the class needs.
|
||||
|
@ -42,8 +43,8 @@ class JitConfig:
|
|||
>>> net.set_jit_config(jitconfig)
|
||||
"""
|
||||
def __init__(self, jit_level="O1", task_sink=True, **kwargs):
|
||||
if jit_level not in ["O0", "O1", "O2"]:
|
||||
raise ValueError("For 'jit_level' must be one of ['O0', 'O1', 'O2'].")
|
||||
if jit_level not in ["O0", "O1", "O2", "O3"]:
|
||||
raise ValueError("For 'jit_level' must be one of ['O0', 'O1', 'O2', 'O3'].")
|
||||
if not isinstance(task_sink, bool):
|
||||
raise TypeError("For 'task_sink' must be bool.")
|
||||
self.jit_config_dict = dict()
|
||||
|
|
|
@ -1603,6 +1603,12 @@ class Cell(Cell_):
|
|||
logger.warning("For Cell, jit config can only be set once, ignore this setting.")
|
||||
else:
|
||||
self._jit_config_dict = jit_config.jit_config_dict
|
||||
enable_ge = os.getenv("MS_ENABLE_GE") == '1'
|
||||
enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
|
||||
if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
|
||||
raise RuntimeError("GE and jit_level=O3 should be used together, but "
|
||||
"got MS_ENABLE_GE={}, jie_level={}".format(
|
||||
os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
|
||||
|
||||
def flatten_weights(self, fusion_size=0):
|
||||
"""
|
||||
|
|
|
@ -375,6 +375,8 @@ class TrainOneStepCell(Cell):
|
|||
group=server_group_name)
|
||||
else:
|
||||
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
|
||||
if network.jit_config_dict:
|
||||
self._jit_config_dict = network.jit_config_dict
|
||||
|
||||
def construct(self, *inputs):
|
||||
loss = self.network(*inputs)
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import ge_infer_env # pylint: disable=unused-import
|
||||
from mindspore import nn
|
||||
from mindspore import ops
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common import JitConfig
|
||||
|
||||
|
||||
class NetInner(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetInner, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
|
||||
def construct(self, x, y):
|
||||
output = self.addn((x, y))
|
||||
return output
|
||||
|
||||
|
||||
class NetOuter(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetOuter, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
self.inner_o3 = NetInner()
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o3(x, y)
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o3(x, y)
|
||||
return x
|
||||
|
||||
|
||||
def test_ge_graph_mode_with_jit_level_o3():
|
||||
"""
|
||||
Feature: GE with jit_level.
|
||||
Description: Graph Mode jit_level==O3 with GE.
|
||||
Expectation: Run by ge_device_context when jit_level==O3.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
net = NetOuter()
|
||||
net.set_jit_config(JitConfig(jit_level="O3"))
|
||||
output_o3 = net(inputs, inputs)
|
||||
expected = np.array([[5, 5, 5], [5, 5, 5], [5, 5, 5]], np.float32)
|
||||
np.allclose(output_o3.asnumpy(), expected, 1e-05, 1e-05)
|
||||
|
||||
|
||||
def test_ge_graph_mode_with_jit_level_o2():
|
||||
"""
|
||||
Feature: GE with jit_level.
|
||||
Description: Graph Mode jit_level==O2 with GE.
|
||||
Expectation: Raise ValueError when jit_level==O2/O1/O0.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
net = NetOuter()
|
||||
with pytest.raises(RuntimeError):
|
||||
net.set_jit_config(JitConfig(jit_level="O2"))
|
||||
output_o2 = net(inputs, inputs)
|
||||
print("===>output:", output_o2)
|
||||
|
||||
|
||||
def test_ge_graph_mode_without_jit_level():
|
||||
"""
|
||||
Feature: GE with jit_level.
|
||||
Description: Graph Mode jit_level==None with GE.
|
||||
Expectation: Run by ge_device_context without jit_level.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
net = NetOuter()
|
||||
output = net(inputs, inputs)
|
||||
expected = np.array([[5, 5, 5], [5, 5, 5], [5, 5, 5]], np.float32)
|
||||
np.allclose(output.asnumpy(), expected, 1e-05, 1e-05)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_ge_graph_mode_with_jit_level_o3()
|
||||
test_ge_graph_mode_with_jit_level_o2()
|
||||
test_ge_graph_mode_without_jit_level()
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import ge_infer_env # pylint: disable=unused-import
|
||||
from mindspore import nn
|
||||
from mindspore import ops
|
||||
from mindspore import context, Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.common import JitConfig
|
||||
|
||||
|
||||
class NetInnerO3(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetInnerO3, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
|
||||
@ms_function(jit_config=JitConfig(jit_level="O3"))
|
||||
def construct(self, x, y):
|
||||
output = self.addn((x, y))
|
||||
return output
|
||||
|
||||
|
||||
class NetInnerO2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetInnerO2, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
|
||||
@ms_function(jit_config=JitConfig(jit_level="O2"))
|
||||
def construct(self, x, y):
|
||||
output = self.addn((x, y))
|
||||
return output
|
||||
|
||||
|
||||
class TwoO3MsFuncNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TwoO3MsFuncNet, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
self.inner_o3 = NetInnerO3()
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o3(x, y)
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o3(x, y)
|
||||
return x
|
||||
|
||||
|
||||
class O2NestedOneO2OneO3MsFuncNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(O2NestedOneO2OneO3MsFuncNet, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
self.inner_o2 = NetInnerO2()
|
||||
self.inner_o3 = NetInnerO3()
|
||||
|
||||
@ms_function(jit_config=JitConfig(jit_level="O2"))
|
||||
def construct(self, x, y):
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o2(x, y)
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o3(x, y)
|
||||
return x
|
||||
|
||||
|
||||
class O3NestedTwoO3MsFuncNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(O3NestedTwoO3MsFuncNet, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
self.inner_o3 = NetInnerO3()
|
||||
|
||||
@ms_function(jit_config=JitConfig(jit_level="O3"))
|
||||
def construct(self, x, y):
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o3(x, y)
|
||||
x = self.addn((x, y))
|
||||
x = self.inner_o3(x, y)
|
||||
return x
|
||||
|
||||
|
||||
def test_pynative_o2_jit_level_ms_function_with_ge():
|
||||
"""
|
||||
Feature: PyNative ms function with GE.
|
||||
Description: jit_level=O2 ms function with GE.
|
||||
Expectation: Raise ValueError.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
with pytest.raises(RuntimeError):
|
||||
net = NetInnerO2()
|
||||
output = net(inputs, inputs)
|
||||
print("===>output:", output)
|
||||
|
||||
|
||||
def test_pynative_o3_jit_level_ms_function_with_ge():
|
||||
"""
|
||||
Feature: PyNative ms function with GE.
|
||||
Description: jit_level=O3 ms function with GE.
|
||||
Expectation: Run by ascend_device_context rather than ge_device_context.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
net = NetInnerO3()
|
||||
output = net(inputs, inputs)
|
||||
expected = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]], np.float32)
|
||||
np.allclose(output.asnumpy(), expected, 1e-05, 1e-05)
|
||||
|
||||
|
||||
def test_pynative_two_o3_jit_level_ms_function_with_ge():
|
||||
"""
|
||||
Feature: PyNative ms function with GE.
|
||||
Description: Two jit_level=O3 ms function with GE.
|
||||
Expectation: Raise RuntimeError when pynative.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
with pytest.raises(RuntimeError):
|
||||
net = TwoO3MsFuncNet()
|
||||
output = net(inputs, inputs)
|
||||
print("===>output:", output)
|
||||
|
||||
|
||||
def test_pynative_o2_nested_one_o2_one_o3_jit_level_ms_function_with_ge():
|
||||
"""
|
||||
Feature: PyNative ms function with GE.
|
||||
Description: O2 nested O2 + O3 ms function with GE.
|
||||
Expectation: Raise ValueError, GE only support O3.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
with pytest.raises(RuntimeError):
|
||||
net = O2NestedOneO2OneO3MsFuncNet()
|
||||
output = net(inputs, inputs)
|
||||
print("===>output:", output)
|
||||
|
||||
|
||||
def test_pynative_o3_nested_two_o3_jit_level_ms_function_with_ge():
|
||||
"""
|
||||
Feature: PyNative ms function with GE.
|
||||
Description: Nested jit_level=O3 ms function with GE.
|
||||
Expectation: Run by ge_device_context.
|
||||
"""
|
||||
context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
|
||||
inputs = Tensor(np.ones((3, 3), np.float32))
|
||||
net = O3NestedTwoO3MsFuncNet()
|
||||
output = net(inputs, inputs)
|
||||
expected = np.array([[5, 5, 5], [5, 5, 5], [5, 5, 5]], np.float32)
|
||||
np.allclose(output.asnumpy(), expected, 1e-05, 1e-05)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pynative_o2_jit_level_ms_function_with_ge()
|
||||
test_pynative_o3_jit_level_ms_function_with_ge()
|
||||
test_pynative_two_o3_jit_level_ms_function_with_ge()
|
||||
test_pynative_o2_nested_one_o2_one_o3_jit_level_ms_function_with_ge()
|
||||
test_pynative_o3_nested_two_o3_jit_level_ms_function_with_ge()
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
|
||||
def run_testcase(file_name, case_name=""):
|
||||
log_file = file_name + "_" + case_name + '.log'
|
||||
if case_name == "":
|
||||
ret = os.system(f'{sys.executable} {file_name}.py &> {log_file}')
|
||||
else:
|
||||
ret = os.system(f"{sys.executable} -c 'import {file_name};{file_name}.{case_name}()' &> {log_file}")
|
||||
os.system(f'grep -E "CRITICAL|ERROR|Error" {log_file} -C 3')
|
||||
os.system(f'rm {log_file} -rf')
|
||||
assert ret == 0
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ge_graph_mode_with_jit_level():
|
||||
"""
|
||||
Description: Graph Mode jit_level==O3 with GE.
|
||||
Expectation: Run by ge_device_context when jit_level==O3.
|
||||
"""
|
||||
run_testcase('ge_graph_mode_jit_level')
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_ms_function_with_ge():
|
||||
"""
|
||||
Description: PyNative ms function with GE.
|
||||
Expectation: Run by ge_device_context when jit_level==O3.
|
||||
"""
|
||||
run_testcase('ge_pynative_mode_jit_level')
|
Loading…
Reference in New Issue