Unify the infer and train builds into a single compiled binary
This commit is contained in:
parent
d1ca2fdae0
commit
54fc038f0a
|
@ -2,7 +2,6 @@ option(ENABLE_D "Enable d" OFF)
|
|||
option(ENABLE_GPU "Enable gpu" OFF)
|
||||
option(ENABLE_CPU "Enable cpu" OFF)
|
||||
option(ENABLE_MINDDATA "Enable minddata compile" OFF)
|
||||
option(ENABLE_TRAIN "Enable ge train, default off(only infer)" OFF)
|
||||
option(ENABLE_TESTCASES "Run testcases switch, default off" OFF)
|
||||
option(ENABLE_CPP_ST "Run cpp st testcases switch, default off" OFF)
|
||||
option(DEBUG_MODE "Debug mode, default off" OFF)
|
||||
|
@ -95,12 +94,6 @@ if(ENABLE_D)
|
|||
add_compile_definitions(CUSTOM_OP)
|
||||
endif()
|
||||
|
||||
if(ENABLE_TRAIN)
|
||||
add_compile_definitions(ENABLE_TRAIN=1)
|
||||
else()
|
||||
add_compile_definitions(ENABLE_TRAIN=0)
|
||||
endif()
|
||||
|
||||
if(USE_GLOG)
|
||||
add_compile_definitions(USE_GLOG)
|
||||
endif()
|
||||
|
|
|
@ -49,11 +49,7 @@ bool CreateSessionAndGraphRunner() {
|
|||
options["ge.trainFlag"] = "0";
|
||||
options["ge.enablePrintOpPass"] = "0";
|
||||
sess = transform::GraphRunner::NewSession(options);
|
||||
if (sess == nullptr) {
|
||||
MS_LOG(WARNING) << "Init data graph failed, because of create Ge session failed";
|
||||
} else {
|
||||
transform::DfGraphManager::GetInstance().SetGeSession(sess);
|
||||
}
|
||||
transform::DfGraphManager::GetInstance().SetGeSession(sess);
|
||||
}
|
||||
|
||||
transform::GraphRunnerOptions options;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/hash_map.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "ir/tensor.h"
|
||||
|
@ -129,13 +130,14 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
|
|||
return false;
|
||||
}
|
||||
|
||||
#if ENABLE_TRAIN
|
||||
(void)setenv("GE_TRAIN", "1", 1);
|
||||
#else
|
||||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
#endif
|
||||
auto training = ConfigManager::GetInstance().training();
|
||||
if (training) {
|
||||
(void)setenv("GE_TRAIN", "1", 1);
|
||||
} else {
|
||||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
}
|
||||
|
||||
if (CreateSessionAndGraphRunner(static_cast<bool>(ENABLE_TRAIN)) != Status::SUCCESS) {
|
||||
if (CreateSessionAndGraphRunner(training) != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Create GE Session or GraphRunner failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -244,6 +246,7 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
|
|||
MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase);
|
||||
}
|
||||
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
|
||||
ConfigManager::GetInstance().set_training(anf_graph->has_flag("training"));
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
draw::Draw("anf_graph.dot", anf_graph); // for debug
|
||||
|
@ -256,13 +259,14 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
#if ENABLE_TRAIN
|
||||
(void)setenv("GE_TRAIN", "1", 1);
|
||||
#else
|
||||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
#endif
|
||||
auto training = ConfigManager::GetInstance().training();
|
||||
if (training) {
|
||||
(void)setenv("GE_TRAIN", "1", 1);
|
||||
} else {
|
||||
(void)setenv("GE_TRAIN", "0", 1);
|
||||
}
|
||||
|
||||
if (CreateSessionAndGraphRunner(static_cast<bool>(ENABLE_TRAIN)) != Status::SUCCESS) {
|
||||
if (CreateSessionAndGraphRunner(training) != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Create GE Session or GraphRunner failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -489,16 +493,6 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
|
|||
}
|
||||
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
|
||||
|
||||
#ifdef ENABLE_INFER
|
||||
// Now don't use the graph because the exec ge function don't take effect
|
||||
MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph);
|
||||
if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) {
|
||||
MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries";
|
||||
ConfigManager::GetInstance().ResetConfig();
|
||||
return py::none();
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<py::object> ret_val = std::make_shared<py::object>();
|
||||
// We will not execute graph when output is constant or just input itself.
|
||||
if (IsGraphOutputValueNodeOrParameter(info.at(phase)->func_graph->output(), args, ret_val)) {
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -36,7 +37,6 @@
|
|||
#include "transform/graph_ir/util.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "transform/graph_ir/df_graph_manager.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "transform/graph_ir/op_adapter.h"
|
||||
#include "graph/operator_reg.h"
|
||||
#ifdef OPEN_SOURCE
|
||||
|
@ -56,15 +56,7 @@ class DfGraphConvertor {
|
|||
explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString());
|
||||
#if (defined ENABLE_D) && (!defined ENABLE_INFER)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->backend_policy() == "ge") {
|
||||
training_ = ENABLE_TRAIN;
|
||||
}
|
||||
#else
|
||||
training_ = anf_graph->has_flag("training");
|
||||
#endif
|
||||
distribute_ = anf_graph->has_flag("broadcast_flag");
|
||||
if (anf_graph->has_flag("broadcast_flag")) {
|
||||
ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION);
|
||||
|
|
|
@ -117,6 +117,10 @@ class ConfigManager {
|
|||
|
||||
void set_gpu_loopsink_size(const int64_t size) { gpu_loopsink_size_ = size; }
|
||||
|
||||
// Returns the current value of the training_ flag.
bool training() const { return training_; }
|
||||
|
||||
// Overwrites the training_ flag with the given value.
void set_training(const bool training) { training_ = training; }
|
||||
|
||||
private:
|
||||
ConfigManager() = default;
|
||||
~ConfigManager() = default;
|
||||
|
@ -130,6 +134,7 @@ class ConfigManager {
|
|||
std::map<std::string, int16_t> queue_info_map;
|
||||
std::string dataset_phase_{""};
|
||||
int64_t gpu_loopsink_size_{1};
|
||||
bool training_{false};
|
||||
};
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <thread>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#ifndef NO_DLIB
|
||||
|
@ -194,9 +195,10 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
|
|||
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
#if ENABLE_TRAIN == 1
|
||||
(*ge_options)["ge.graphRunMode"] = "1";
|
||||
#endif
|
||||
if (ConfigManager::GetInstance().training()) {
|
||||
(*ge_options)["ge.graphRunMode"] = "1";
|
||||
}
|
||||
|
||||
SetDisableReuseMemoryFlag(ge_options);
|
||||
SetHcclOptions(ms_context_ptr, ge_options);
|
||||
|
||||
|
|
|
@ -35,9 +35,6 @@ build_mindspore()
|
|||
if [[ -n "$ENABLE_BACKEND" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${ENABLE_BACKEND}=ON"
|
||||
fi
|
||||
if [[ -n "$TRAIN_MODE" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${TRAIN_MODE}=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_SYM_FILE" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SYM_FILE=ON"
|
||||
fi
|
||||
|
|
|
@ -29,7 +29,6 @@ init_default_options()
|
|||
export RUN_TESTCASES="off"
|
||||
export RUN_CPP_ST_TESTS="off"
|
||||
export ENABLE_BACKEND=""
|
||||
export TRAIN_MODE="INFER"
|
||||
export ENABLE_ASAN="off"
|
||||
export ENABLE_PROFILE="off"
|
||||
export INC_BUILD="off"
|
||||
|
|
|
@ -35,16 +35,6 @@ build_option_proc_l()
|
|||
export ENABLE_PYTHON="$OPTARG"
|
||||
}
|
||||
|
||||
# Handler for the -m build option.
# Accepts only "infer" or "train" in $OPTARG; any other value prints an error,
# calls usage and exits with status 1. On success, exports TRAIN_MODE set to
# the upper-cased option value.
build_option_proc_m()
|
||||
{
|
||||
if [[ "X$OPTARG" != "Xinfer" && "X$OPTARG" != "Xtrain" ]]; then
|
||||
echo "Invalid value ${OPTARG} for option -m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
# Upper-case the validated mode (infer -> INFER, train -> TRAIN) before exporting it.
export TRAIN_MODE=$(echo "$OPTARG" | tr '[a-z]' '[A-Z]')
|
||||
}
|
||||
|
||||
build_option_proc_s()
|
||||
{
|
||||
check_on_off $OPTARG s
|
||||
|
|
|
@ -20,7 +20,7 @@ set -e
|
|||
process_options()
|
||||
{
|
||||
# Process the options
|
||||
while getopts 'drvj:c:t:hb:s:a:g:p:ie:m:l:I:RP:D:zM:V:K:B:En:A:S:k:W:F:H:L:y' opt
|
||||
while getopts 'drvj:c:t:hb:s:a:g:p:ie:l:I:RP:D:zM:V:K:B:En:A:S:k:W:F:H:L:y' opt
|
||||
do
|
||||
CASE_SENSIVE_ARG=${OPTARG}
|
||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||
|
@ -55,8 +55,6 @@ process_options()
|
|||
build_option_proc_l ;;
|
||||
i)
|
||||
export INC_BUILD="on" ;;
|
||||
m)
|
||||
build_option_proc_m ;;
|
||||
s)
|
||||
build_option_proc_s ;;
|
||||
R)
|
||||
|
|
|
@ -260,12 +260,7 @@ if(MINDSPORE_PROTO_LIST)
|
|||
endif()
|
||||
|
||||
if(ENABLE_D)
|
||||
if(ENABLE_TRAIN)
|
||||
target_link_libraries(ut_tests PRIVATE graph ge_runner)
|
||||
else()
|
||||
target_link_libraries(ut_tests PRIVATE graph ge_client)
|
||||
endif()
|
||||
|
||||
target_link_libraries(ut_tests PRIVATE graph ge_runner ge_client)
|
||||
target_link_libraries(mindspore PRIVATE tsdclient)
|
||||
endif()
|
||||
|
||||
|
|
Loading…
Reference in New Issue