graph_kernel support ge
Signed-off-by: wangtongyu <wangtongyu6@huawei.com>
This commit is contained in:
parent
324b3ee65b
commit
669c5db57a
|
@ -57,11 +57,6 @@ IF NOT EXIST "%BUILD_PATH%/mindspore" (
|
|||
md "mindspore"
|
||||
)
|
||||
|
||||
IF "%ENABLE_AKG%" == "1" (
|
||||
echo "enable akg"
|
||||
SET ENABLE_GITEE=ON
|
||||
)
|
||||
|
||||
cd %BUILD_PATH%/mindspore
|
||||
IF "%1%" == "lite" (
|
||||
echo "======Start building MindSpore Lite %VERSION_STR%======"
|
||||
|
|
|
@ -196,11 +196,11 @@ void GraphKernelFlags::SaveJitConfig(const std::map<std::string, std::string> &j
|
|||
|
||||
std::pair<std::string, bool> GraphKernelFlags::GetGraphKernelConfig() {
|
||||
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
|
||||
const auto &jit_config = GetJitConfig();
|
||||
std::string flags = common::GetEnv("MS_DEV_GRAPH_KERNEL_FLAGS");
|
||||
if (flags != "") {
|
||||
return std::make_pair(flags, false);
|
||||
}
|
||||
const auto &jit_config = GetJitConfig();
|
||||
if (jit_config.find("graph_kernel_flags") != jit_config.end()) {
|
||||
flags = jit_config.at("graph_kernel_flags");
|
||||
}
|
||||
|
|
|
@ -803,7 +803,7 @@ update_submodule()
|
|||
cd "${BASEPATH}/graphengine"
|
||||
git submodule update --init 910/metadef
|
||||
cd "${BASEPATH}"
|
||||
if [[ ("X$MSLITE_ENABLE_GRAPH_KERNEL" = "Xon" && ("${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}" == "on") || ("${MSLITE_ENABLE_CLOUD_INFERENCE}" == "on") ) ]]; then
|
||||
if [[ ("X$ENABLE_AKG" = "Xon" && ("${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}" == "on") || ("${MSLITE_ENABLE_CLOUD_INFERENCE}" == "on") ) ]]; then
|
||||
if [[ ("${MSLITE_ENABLE_ACL}" == "on") ]]; then
|
||||
git submodule update --init akg
|
||||
else
|
||||
|
|
|
@ -47,6 +47,22 @@ constexpr auto kDumpMode = "dump_mode";
|
|||
constexpr auto kProfiling = "profiler";
|
||||
constexpr auto kDataFlowGraphType = "data_flow";
|
||||
constexpr auto kCustomInputSize = 2;
|
||||
constexpr auto kGraphKernelParam = "graph_kernel_param";
|
||||
|
||||
std::shared_ptr<ConverterPara> ParseGraphKernelConfigs(const ConfigInfos &maps) {
|
||||
if (maps.find(kGraphKernelParam) == maps.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto param = std::make_shared<ConverterPara>();
|
||||
const auto &gk_map = maps.at(kGraphKernelParam);
|
||||
std::stringstream oss;
|
||||
for (const auto &item : gk_map) {
|
||||
oss << "--" << item.first << "=" << item.second << " ";
|
||||
}
|
||||
param->device = "Ascend";
|
||||
param->graphKernelParam.graph_kernel_flags = oss.str();
|
||||
return param;
|
||||
}
|
||||
|
||||
transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
|
@ -264,7 +280,8 @@ bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &anf_graph, const std::map
|
|||
return false;
|
||||
}
|
||||
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
|
||||
if (GraphKernelOptimize(anf_graph, nullptr) != lite::RET_OK) {
|
||||
auto param = ParseGraphKernelConfigs(config_infos_);
|
||||
if (GraphKernelOptimize(anf_graph, param) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Run graphkernel optimization failed.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http:__www.apache.org_licenses_LICENSE-2.0
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http:__www.apache.org_licenses_LICENSE-2.0
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
|
|
@ -102,8 +102,8 @@ Flags::Flags() {
|
|||
#else
|
||||
AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR_LITE");
|
||||
#endif
|
||||
AddFlag(&Flags::optimizeStr, "optimize",
|
||||
"The type of optimization. none | general | gpu_oriented | ascend_oriented | cpu_oriented", "general");
|
||||
AddFlag(&Flags::optimizeStr, "optimize", "The type of optimization. none | general | gpu_oriented | ascend_oriented",
|
||||
"general");
|
||||
AddFlag(&Flags::optimizeTransformerStr, "optimizeTransformer", "Enable Fast-Transformer fusion true|false", "false");
|
||||
}
|
||||
|
||||
|
@ -288,12 +288,8 @@ int Flags::InitOptimize() {
|
|||
} else if (this->optimizeStr == "ascend_oriented") {
|
||||
this->disableFusion = false;
|
||||
this->device = "Ascend";
|
||||
} else if (this->optimizeStr == "cpu_oriented") {
|
||||
this->disableFusion = false;
|
||||
this->device = "CPU";
|
||||
} else if (!this->optimizeStr.empty()) {
|
||||
std::cerr << "INPUT ILLEGAL: optimize must be none|general|gpu_oriented|ascend_oriented|cpu_oriented|general"
|
||||
<< std::endl;
|
||||
std::cerr << "INPUT ILLEGAL: optimize must be none|general|gpu_oriented|ascend_oriented" << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
|
@ -165,12 +165,7 @@ lite::STATUS GraphKernelOptimize(const FuncGraphPtr &func_graph, const std::shar
|
|||
}
|
||||
if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
MS_LOG(INFO) << "Run graphkernel optimization begin.";
|
||||
auto p = param;
|
||||
if (p == nullptr) {
|
||||
p = std::make_shared<ConverterPara>();
|
||||
p->device = "Ascend";
|
||||
}
|
||||
graphkernel::GraphKernelOptimizer(p).Run(func_graph);
|
||||
graphkernel::GraphKernelOptimizer(param).Run(func_graph);
|
||||
MS_LOG(INFO) << "Run graphkernel optimization end.";
|
||||
}
|
||||
return lite::RET_OK;
|
||||
|
|
Loading…
Reference in New Issue