!13421 Register AKG kernel

From: @wilfchen
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
This commit is contained in:
mindspore-ci-bot 2021-03-18 14:27:42 +08:00 committed by Gitee
commit d6ddd4a107
6 changed files with 53 additions and 82 deletions

View File

@ -10,7 +10,7 @@ if(ENABLE_ACL)
include_directories(${CMAKE_SOURCE_DIR}/graphengine/ge)
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"python_utils.cc"
"akg_kernel_register.cc"
"model/acl/*.cc"
"model/model_converter_utils/*.cc"
"graph/acl/*.cc"
@ -19,11 +19,12 @@ endif()
if(ENABLE_D)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc")
"akg_kernel_register.cc" "model/ms/*.cc" "graph/ascend/*.cc")
endif()
if(ENABLE_GPU)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/gpu/*.cc")
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"akg_kernel_register.cc" "model/ms/*.cc" "graph/gpu/*.cc")
endif()
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc

View File

@ -13,52 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/python_utils.h"
#include "cxx_api/akg_kernel_register.h"
#include <dlfcn.h>
#include <mutex>
#include <vector>
#include <memory>
#include <string>
#include <fstream>
#include "mindspore/core/utils/ms_context.h"
#include "pybind11/pybind11.h"
#include "backend/kernel_compiler/oplib/oplib.h"
namespace py = pybind11;
static std::mutex init_mutex;
static bool Initialized = false;
namespace mindspore {
static void RegAllOpFromPython() {
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
Py_Initialize();
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
MS_EXCEPTION_IF_NULL(c_expression);
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
MS_EXCEPTION_IF_NULL(c_expression_dict);
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
MS_EXCEPTION_IF_NULL(op_info_loader_class);
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
MS_EXCEPTION_IF_NULL(op_info_loader);
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
MS_EXCEPTION_IF_NULL(op_info_loader_ins);
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul);
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
for (auto op_info : *all_ops_info) {
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
}
all_ops_info->clear();
delete all_ops_info;
Py_DECREF(op_info_loader);
Py_DECREF(op_info_loader_class);
Py_DECREF(c_expression_dict);
Py_DECREF(c_expression);
}
static bool RegAllOpFromFile() {
Dl_info info;
int dl_ret = dladdr(reinterpret_cast<void *>(RegAllOpFromFile), &info);
@ -111,36 +77,10 @@ void RegAllOp() {
}
bool ret = RegAllOpFromFile();
if (!ret) {
MS_LOG(INFO) << "Reg all op from file failed, start to reg from python.";
RegAllOpFromPython();
MS_LOG(ERROR) << "Register operators failed. The package may damaged or file is missing.";
return;
}
Initialized = true;
}
bool PythonIsInited() { return Py_IsInitialized() != 0; }
void InitPython() {
if (!PythonIsInited()) {
Py_Initialize();
}
}
void FinalizePython() {
if (PythonIsInited()) {
Py_Finalize();
}
}
PythonEnvGuard::PythonEnvGuard() {
origin_init_status_ = PythonIsInited();
InitPython();
}
PythonEnvGuard::~PythonEnvGuard() {
// finalize when init by this
if (!origin_init_status_) {
FinalizePython();
}
}
} // namespace mindspore

View File

@ -13,22 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
#define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
#ifndef MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_
#define MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_
namespace mindspore {
void RegAllOp();
bool PythonIsInited();
void InitPython();
void FinalizePython();
class PythonEnvGuard {
public:
PythonEnvGuard();
~PythonEnvGuard();
private:
bool origin_init_status_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
#endif // MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_

View File

@ -17,7 +17,7 @@
#include <algorithm>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/python_utils.h"
#include "cxx_api/akg_kernel_register.h"
#include "utils/log_adapter.h"
#include "utils/context/context_extends.h"
#include "mindspore/core/base/base_ref_utils.h"
@ -27,6 +27,7 @@
#include "runtime/dev.h"
#include "pipeline/jit/pipeline.h"
#include "frontend/parallel/step_parallel.h"
#include "pybind11/pybind11.h"
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
@ -380,4 +381,30 @@ std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
PythonEnvGuard::PythonEnvGuard() {
origin_init_status_ = PythonIsInited();
InitPython();
}
PythonEnvGuard::~PythonEnvGuard() {
// finalize when init by this
if (!origin_init_status_) {
FinalizePython();
}
}
bool PythonEnvGuard::PythonIsInited() { return Py_IsInitialized() != 0; }
void PythonEnvGuard::InitPython() {
if (!PythonIsInited()) {
Py_Initialize();
}
}
void PythonEnvGuard::FinalizePython() {
if (PythonIsInited()) {
Py_Finalize();
}
}
} // namespace mindspore

View File

@ -79,5 +79,17 @@ class AscendGraphImpl::MsEnvGuard {
Status errno_;
uint32_t device_id_;
};
class PythonEnvGuard {
public:
PythonEnvGuard();
~PythonEnvGuard();
private:
bool PythonIsInited();
void InitPython();
void FinalizePython();
bool origin_init_status_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H

View File

@ -17,6 +17,7 @@
#include <algorithm>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/akg_kernel_register.h"
#include "utils/log_adapter.h"
#include "mindspore/core/base/base_ref_utils.h"
#include "backend/session/session_factory.h"
@ -43,6 +44,8 @@ Status GPUGraphImpl::InitEnv() {
return kSuccess;
}
// Register op implemented with AKG.
RegAllOp();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";