GE support PyNative ms_function.

Update Chinese docs.

Opt
This commit is contained in:
liangzelang 2022-10-25 19:53:15 +08:00
parent 666428a765
commit bcbf41238c
17 changed files with 369 additions and 11 deletions

View File

@ -9,11 +9,12 @@ mindspore.JitConfig
- 这是一个实验性接口,后续可能删除或修改。
参数:
- **jit_level** (str) - 设置编译优化的级别,支持["O0", "O1", "O2"]。默认值:"O1"。
- **jit_level** (str) - 设置编译优化的级别,支持["O0", "O1", "O2", "O3"]。默认值:"O1"。
- "O0": 基础优化。
- "O1": 手动优化。
- "O2": 手动优化与图算优化结合。
- "O3": 性能优化,无法保证泛化性。
- **task_sink** (bool) - 数据是否直接下沉至处理器进行处理。默认值:True。
- **kwargs** (dict) - 关键字参数字典。

View File

@ -311,9 +311,11 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
func_graph_to_kernel_graph_ids_.clear();
control_nodes_.clear();
auto jit_level = common::AnfAlgo::GetJitLevel(func_graph);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_}, jit_level);
MS_EXCEPTION_IF_NULL(device_context);
device_context->Initialize();
bool all_support = device_context->PartitionGraph(func_graph);
if (all_support) {
auto run_mode = device_context->GetRunMode(func_graph);
@ -331,7 +333,8 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
// Construct the graph compiler info.
auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
MS_EXCEPTION_IF_NULL(graph_compiler_info);
if (ms_execution_mode_ == kGraphMode &&
if ((ms_execution_mode_ == kGraphMode ||
(ms_execution_mode_ == kPynativeMode && jit_level == "O3" && context_ptr->backend_policy() == "ge")) &&
((!graph_compiler_info->graphs_.empty()) || graph_compiler_info->control_nodes_.size() > 1)) {
// Transform graph to actor DAG, and schedule the actor DAG.
ParseControlNodes(*graph_compiler_info);

View File

@ -286,6 +286,9 @@ class COMMON_EXPORT AnfAlgo {
static std::string GetTensorValueString(const tensor::TensorPtr &tensor);
static abstract::AbstractBasePtr GetNodeAbstractByIndex(const AnfNodePtr &node, size_t index);
// Get jit level from func_graph
static std::string GetJitLevel(const FuncGraphPtr &func_graph);
};
} // namespace common
} // namespace mindspore

View File

@ -681,6 +681,7 @@ constexpr auto kAttrNeedGradFlagOfInputs = "need_grad_flag_of_inputs";
constexpr auto kAttrIsCNodeNeedGrad = "is_cnode_need_grad";
constexpr auto kAttrJitLevel = "jit_level";
constexpr auto kAttrJitLevelO2 = "O2";
constexpr auto kAttrJitLevelO3 = "O3";
constexpr auto kAttrCellJitConfigDict = "_jit_config_dict";
constexpr auto kAttrBinaryOutput = "binary_output";
constexpr auto kAttrMinLength = "minlength";

View File

@ -392,8 +392,9 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
}
// todo: waiting for GraphExecutor
auto jit_level = common::AnfAlgo::GetJitLevel(func_graph);
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->backend_policy() == "ge") {
if (MsContext::GetInstance()->backend_policy() == "ge" && (jit_level == "O3" || jit_level == "")) {
auto manager = MakeManager();
MS_EXCEPTION_IF_NULL(manager);
for (const auto &graph : all_graphs) {

View File

@ -226,13 +226,14 @@ void DeviceContextManager::ClearDeviceContexts() {
device_contexts_.clear();
}
DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key) {
DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key,
string jit_level /* ="" */) {
std::string device_context_key_str = device_context_key.ToString();
std::string name = device_context_key.device_name_;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->backend_policy() == "ge") {
if (ms_context->backend_policy() == "ge" && (jit_level == kAttrJitLevelO3 || jit_level == "")) {
name = "GE";
device_context_key_str = "GE_0";
}

View File

@ -48,7 +48,7 @@ class BACKEND_EXPORT DeviceContextManager {
public:
static DeviceContextManager &GetInstance();
void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key);
DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key, string jit_level = "");
void UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key);
void ClearDeviceContexts();
void WaitTaskFinishOnDevice() const;

View File

@ -357,7 +357,7 @@ bool GraphAdapter::PyNativeEnableTaskSink(const FuncGraphPtr &func_graph) {
}
auto jit_level_value = func_graph->get_attr(kAttrJitLevel);
auto jit_level = GetValue<std::string>(jit_level_value);
if (jit_level != kAttrJitLevelO2) {
if (jit_level != kAttrJitLevelO2 && jit_level != kAttrJitLevelO3) {
MS_LOG(INFO) << "jit_level is " << jit_level << ", task sink is disabled";
return false;
}

View File

@ -81,6 +81,7 @@ OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run
graph->set_run_mode(device::RunMode::kKernelMode);
graph->set_is_from_single_op(true);
MS_EXCEPTION_IF_NULL(device_context->kernel_executor_);
// session_ is SessionBasic, AscendUnifyMindIR has not been executed.
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());

View File

@ -1750,5 +1750,16 @@ abstract::AbstractBasePtr AnfAlgo::GetNodeAbstractByIndex(const AnfNodePtr &node
}
return elements[index];
}
// Looks up the "jit_level" attribute on the given func_graph.
// Returns the attribute's string value when it is present, otherwise logs at INFO level and returns an empty string.
// Raises (via MS_EXCEPTION_IF_NULL) when func_graph is null.
std::string AnfAlgo::GetJitLevel(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (func_graph->has_attr(kAttrJitLevel)) {
    // The attribute value is extracted as a std::string.
    return GetValue<std::string>(func_graph->get_attr(kAttrJitLevel));
  }
  MS_LOG(INFO) << "The func_graph:" << func_graph->ToString() << " has no jit_level attr, return default: None.";
  return "";
}
} // namespace common
} // namespace mindspore

View File

@ -300,6 +300,15 @@ class _MindsporeFunctionExecutor:
if context.get_context("mode") == context.PYNATIVE_MODE:
_pynative_executor.set_graph_phase(phase)
output = _pynative_executor.grad_ms_function(output, *new_inputs)
enable_ge = os.getenv("MS_ENABLE_GE") == "1"
if enable_ge and self.jit_config_dict is None:
raise RuntimeError("GE and jit_level=O3 should be used together, but jit_config is None.")
if self.jit_config_dict:
enable_jit_level_o3 = self.jit_config_dict.get('jit_level') == "O3"
if (enable_ge and not enable_jit_level_o3) or (not enable_ge and enable_jit_level_o3):
raise RuntimeError("GE and jit_level=O3 should be used together, but "
"got MS_ENABLE_GE={}, jit_level={}".format(
os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
return output

View File

@ -24,11 +24,12 @@ class JitConfig:
Args:
jit_level (str): Option for argument `level` for Optimization of lift graph.
Supports ["O0", "O1", "O2"]. Default: "O1".
Supports ["O0", "O1", "O2", "O3"]. Default: "O1".
- "O0": Basic optimization.
- "O1": Manual optimization.
- "O2": Manual optimization and graph computation fusion.
- "O3": Performance optimization, no generalization guaranteed.
task_sink (bool): Determines whether to pass the data through dataset channel. Default: True.
**kwargs (dict): A dictionary of keyword arguments that the class needs.
@ -42,8 +43,8 @@ class JitConfig:
>>> net.set_jit_config(jitconfig)
"""
def __init__(self, jit_level="O1", task_sink=True, **kwargs):
if jit_level not in ["O0", "O1", "O2"]:
raise ValueError("For 'jit_level' must be one of ['O0', 'O1', 'O2'].")
if jit_level not in ["O0", "O1", "O2", "O3"]:
raise ValueError("For 'jit_level' must be one of ['O0', 'O1', 'O2', 'O3'].")
if not isinstance(task_sink, bool):
raise TypeError("For 'task_sink' must be bool.")
self.jit_config_dict = dict()

View File

@ -1603,6 +1603,12 @@ class Cell(Cell_):
logger.warning("For Cell, jit config can only be set once, ignore this setting.")
else:
self._jit_config_dict = jit_config.jit_config_dict
enable_ge = os.getenv("MS_ENABLE_GE") == '1'
enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
raise RuntimeError("GE and jit_level=O3 should be used together, but "
"got MS_ENABLE_GE={}, jie_level={}".format(
os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
def flatten_weights(self, fusion_size=0):
"""

View File

@ -375,6 +375,8 @@ class TrainOneStepCell(Cell):
group=server_group_name)
else:
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
if network.jit_config_dict:
self._jit_config_dict = network.jit_config_dict
def construct(self, *inputs):
loss = self.network(*inputs)

View File

@ -0,0 +1,97 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import ge_infer_env # pylint: disable=unused-import
from mindspore import nn
from mindspore import ops
from mindspore import context, Tensor
from mindspore.common import JitConfig
class NetInner(nn.Cell):
    """Minimal cell that feeds its two inputs to a single AddN op."""

    def __init__(self):
        super().__init__()
        self.addn = ops.AddN()

    def construct(self, x, y):
        # AddN receives the operands packed into one tuple.
        return self.addn((x, y))
class NetOuter(nn.Cell):
    """Cell that alternates its own AddN with calls into a shared NetInner, two rounds in total."""

    def __init__(self):
        super().__init__()
        self.addn = ops.AddN()
        self.inner_o3 = NetInner()

    def construct(self, x, y):
        # Each step adds y to the running value; y itself is never modified.
        first_sum = self.addn((x, y))
        first_inner = self.inner_o3(first_sum, y)
        second_sum = self.addn((first_inner, y))
        return self.inner_o3(second_sum, y)
def test_ge_graph_mode_with_jit_level_o3():
    """
    Feature: GE with jit_level.
    Description: Graph Mode jit_level==O3 with GE.
    Expectation: Run by ge_device_context when jit_level==O3, and every output element equals 5.
    """
    context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
    inputs = Tensor(np.ones((3, 3), np.float32))
    net = NetOuter()
    net.set_jit_config(JitConfig(jit_level="O3"))
    output_o3 = net(inputs, inputs)
    # NetOuter adds y to x four times, so all-ones inputs give 5 everywhere.
    expected = np.array([[5, 5, 5], [5, 5, 5], [5, 5, 5]], np.float32)
    # np.allclose only returns a bool; without the assert a wrong result would go unnoticed.
    assert np.allclose(output_o3.asnumpy(), expected, rtol=1e-05, atol=1e-05)
def test_ge_graph_mode_with_jit_level_o2():
    """
    Feature: GE with jit_level.
    Description: Graph Mode jit_level==O2 with GE.
    Expectation: Raise RuntimeError when jit_level is O2 instead of O3.
    """
    context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
    ones = Tensor(np.ones((3, 3), np.float32))
    model = NetOuter()
    # Configuring and running the net both happen inside the raises-context.
    with pytest.raises(RuntimeError):
        model.set_jit_config(JitConfig(jit_level="O2"))
        result_o2 = model(ones, ones)
        print("===>output:", result_o2)
def test_ge_graph_mode_without_jit_level():
    """
    Feature: GE with jit_level.
    Description: Graph Mode jit_level==None with GE.
    Expectation: Run by ge_device_context without jit_level, and every output element equals 5.
    """
    context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
    inputs = Tensor(np.ones((3, 3), np.float32))
    net = NetOuter()
    output = net(inputs, inputs)
    # NetOuter adds y to x four times, so all-ones inputs give 5 everywhere.
    expected = np.array([[5, 5, 5], [5, 5, 5], [5, 5, 5]], np.float32)
    # np.allclose only returns a bool; without the assert a wrong result would go unnoticed.
    assert np.allclose(output.asnumpy(), expected, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
    # Run every case in declaration order when the file is executed as a script.
    for case in (test_ge_graph_mode_with_jit_level_o3,
                 test_ge_graph_mode_with_jit_level_o2,
                 test_ge_graph_mode_without_jit_level):
        case()

View File

@ -0,0 +1,169 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import ge_infer_env # pylint: disable=unused-import
from mindspore import nn
from mindspore import ops
from mindspore import context, Tensor
from mindspore import ms_function
from mindspore.common import JitConfig
class NetInnerO3(nn.Cell):
    """Single-AddN cell whose construct is an ms_function configured with jit_level O3."""

    def __init__(self):
        super().__init__()
        self.addn = ops.AddN()

    @ms_function(jit_config=JitConfig(jit_level="O3"))
    def construct(self, x, y):
        # AddN receives the operands packed into one tuple.
        return self.addn((x, y))
class NetInnerO2(nn.Cell):
    """Single-AddN cell whose construct is an ms_function configured with jit_level O2."""

    def __init__(self):
        super().__init__()
        self.addn = ops.AddN()

    @ms_function(jit_config=JitConfig(jit_level="O2"))
    def construct(self, x, y):
        # AddN receives the operands packed into one tuple.
        return self.addn((x, y))
class TwoO3MsFuncNet(nn.Cell):
    """Undecorated outer cell that calls the O3 ms_function cell NetInnerO3 twice."""

    def __init__(self):
        super().__init__()
        self.addn = ops.AddN()
        self.inner_o3 = NetInnerO3()

    def construct(self, x, y):
        # Each step adds y to the running value; y itself is never modified.
        first_sum = self.addn((x, y))
        first_inner = self.inner_o3(first_sum, y)
        second_sum = self.addn((first_inner, y))
        return self.inner_o3(second_sum, y)
class O2NestedOneO2OneO3MsFuncNet(nn.Cell):
    """O2 ms_function cell that nests one O2 ms_function cell and one O3 ms_function cell."""

    def __init__(self):
        super().__init__()
        self.addn = ops.AddN()
        self.inner_o2 = NetInnerO2()
        self.inner_o3 = NetInnerO3()

    @ms_function(jit_config=JitConfig(jit_level="O2"))
    def construct(self, x, y):
        # The O2 inner cell runs first, the O3 inner cell second.
        first_sum = self.addn((x, y))
        after_o2 = self.inner_o2(first_sum, y)
        second_sum = self.addn((after_o2, y))
        return self.inner_o3(second_sum, y)
class O3NestedTwoO3MsFuncNet(nn.Cell):
    """O3 ms_function cell that calls the O3 ms_function cell NetInnerO3 twice."""

    def __init__(self):
        super().__init__()
        self.addn = ops.AddN()
        self.inner_o3 = NetInnerO3()

    @ms_function(jit_config=JitConfig(jit_level="O3"))
    def construct(self, x, y):
        # Each step adds y to the running value; y itself is never modified.
        first_sum = self.addn((x, y))
        first_inner = self.inner_o3(first_sum, y)
        second_sum = self.addn((first_inner, y))
        return self.inner_o3(second_sum, y)
def test_pynative_o2_jit_level_ms_function_with_ge():
    """
    Feature: PyNative ms function with GE.
    Description: jit_level=O2 ms function with GE.
    Expectation: Raise RuntimeError.
    """
    context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
    ones = Tensor(np.ones((3, 3), np.float32))
    # Building and running the net both happen inside the raises-context.
    with pytest.raises(RuntimeError):
        model = NetInnerO2()
        result = model(ones, ones)
        print("===>output:", result)
def test_pynative_o3_jit_level_ms_function_with_ge():
    """
    Feature: PyNative ms function with GE.
    Description: jit_level=O3 ms function with GE.
    Expectation: Run by ascend_device_context rather than ge_device_context, and every output element equals 2.
    """
    context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
    inputs = Tensor(np.ones((3, 3), np.float32))
    net = NetInnerO3()
    output = net(inputs, inputs)
    # NetInnerO3 is one AddN of the two all-ones inputs.
    expected = np.array([[2, 2, 2], [2, 2, 2], [2, 2, 2]], np.float32)
    # np.allclose only returns a bool; without the assert a wrong result would go unnoticed.
    assert np.allclose(output.asnumpy(), expected, rtol=1e-05, atol=1e-05)
def test_pynative_two_o3_jit_level_ms_function_with_ge():
    """
    Feature: PyNative ms function with GE.
    Description: Two jit_level=O3 ms function with GE.
    Expectation: Raise RuntimeError when pynative.
    """
    context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
    ones = Tensor(np.ones((3, 3), np.float32))
    # Building and running the net both happen inside the raises-context.
    with pytest.raises(RuntimeError):
        model = TwoO3MsFuncNet()
        result = model(ones, ones)
        print("===>output:", result)
def test_pynative_o2_nested_one_o2_one_o3_jit_level_ms_function_with_ge():
    """
    Feature: PyNative ms function with GE.
    Description: O2 nested O2 + O3 ms function with GE.
    Expectation: Raise RuntimeError, GE only support O3.
    """
    context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
    ones = Tensor(np.ones((3, 3), np.float32))
    # Building and running the net both happen inside the raises-context.
    with pytest.raises(RuntimeError):
        model = O2NestedOneO2OneO3MsFuncNet()
        result = model(ones, ones)
        print("===>output:", result)
def test_pynative_o3_nested_two_o3_jit_level_ms_function_with_ge():
    """
    Feature: PyNative ms function with GE.
    Description: Nested jit_level=O3 ms function with GE.
    Expectation: Run by ge_device_context, and every output element equals 5.
    """
    context.set_context(device_target="Ascend", mode=context.PYNATIVE_MODE)
    inputs = Tensor(np.ones((3, 3), np.float32))
    net = O3NestedTwoO3MsFuncNet()
    output = net(inputs, inputs)
    # The net adds y to x four times, so all-ones inputs give 5 everywhere.
    expected = np.array([[5, 5, 5], [5, 5, 5], [5, 5, 5]], np.float32)
    # np.allclose only returns a bool; without the assert a wrong result would go unnoticed.
    assert np.allclose(output.asnumpy(), expected, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
    # Run every case in declaration order when the file is executed as a script.
    for case in (test_pynative_o2_jit_level_ms_function_with_ge,
                 test_pynative_o3_jit_level_ms_function_with_ge,
                 test_pynative_two_o3_jit_level_ms_function_with_ge,
                 test_pynative_o2_nested_one_o2_one_o3_jit_level_ms_function_with_ge,
                 test_pynative_o3_nested_two_o3_jit_level_ms_function_with_ge):
        case()

View File

@ -0,0 +1,52 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import re
import subprocess
import sys
import pytest
def _print_error_context(log_file, context=3):
    """Print each log line lying within `context` lines of one that contains CRITICAL, ERROR or Error."""
    with open(log_file, 'r', errors='replace') as log:
        lines = log.read().splitlines()
    wanted = set()
    for idx, line in enumerate(lines):
        if re.search(r'CRITICAL|ERROR|Error', line):
            wanted.update(range(max(0, idx - context), min(len(lines), idx + context + 1)))
    for idx in sorted(wanted):
        print(lines[idx])


def run_testcase(file_name, case_name=""):
    """
    Run a test script, or one function of it, in a child Python interpreter.

    Args:
        file_name (str): Script path without the ".py" suffix. When `case_name` is given it is
            used as a module name and must be importable from the working directory.
        case_name (str): Name of a zero-argument function in the module to call. When empty,
            the whole script is executed. Default: "".

    Raises:
        AssertionError: If the child process exits with a non-zero status.
    """
    log_file = file_name + "_" + case_name + '.log'
    if case_name == "":
        cmd = [sys.executable, f'{file_name}.py']
    else:
        cmd = [sys.executable, '-c', f'import {file_name};{file_name}.{case_name}()']
    try:
        # No shell is involved: the bash-only "&>" redirection is not portable to /bin/sh,
        # where it can background the command and always report status 0.
        with open(log_file, 'w') as log:
            ret = subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT, check=False).returncode
        _print_error_context(log_file)
    finally:
        # Always remove the temporary log, even when the run or the log scan fails.
        if os.path.exists(log_file):
            os.remove(log_file)
    assert ret == 0
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ge_graph_mode_with_jit_level():
    """
    Feature: GE with jit_level.
    Description: Run the whole ge_graph_mode_jit_level script (Graph Mode with GE) in a child process.
    Expectation: The script exits with status 0.
    """
    script = 'ge_graph_mode_jit_level'
    run_testcase(script)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_ms_function_with_ge():
    """
    Feature: PyNative ms function with GE.
    Description: Run the whole ge_pynative_mode_jit_level script (PyNative ms_function with GE) in a child process.
    Expectation: The script exits with status 0.
    """
    script = 'ge_pynative_mode_jit_level'
    run_testcase(script)