Add pre-compile process

This commit is contained in:
wangcong 2020-06-17 20:59:08 +08:00
parent b6209eb841
commit aed393d50f
8 changed files with 27 additions and 13 deletions

View File

@ -28,8 +28,8 @@ build_in_impl_path = get_build_in_impl_path()
# op function list
op_build = "compile"
op_pre_build = "pre_build"
fusion_type_map = {'Convolution': 0, 'ElemWise': 1, 'CommReduce': 2,
'Segment': 3, 'Opaque': 4}
fusion_pattern_start_flag = "fusion_pattern_start"
fusion_pattern_end_flag = "fusion_pattern_end"
def _initialize(impl_path):
"""Initialize"""
@ -43,7 +43,6 @@ def _initialize(impl_path):
sys.path.insert(0, op_module_name)
def build_op(build_type, json_str):
"""
call op functions with function name and input args json_str
@ -169,7 +168,5 @@ def compile_with_json(json_str):
if __name__ == "__main__":
in_args = sys.stdin.readline()
result = compile_with_json(in_args)
if result in fusion_type_map:
exit(fusion_type_map[result])
else:
exit(100)
sys.stdout.write(fusion_pattern_start_flag + str(result) + fusion_pattern_end_flag)
sys.stdout.flush()

View File

@ -88,10 +88,10 @@ def run_compiler(op_json):
try:
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
completed_object = subprocess.run([sys.executable, tbe_compiler], input=op_json, timeout=300,
text=True, capture_output=True, check=False)
text=True, capture_output=True, check=True)
if completed_object:
code = completed_object.returncode
return "Success", str(code)
out = completed_object.stdout
return "Success", out
except subprocess.TimeoutExpired:
tb = traceback.format_exc()
return "TBEException", "PreCompileTimeOut: " + tb + "\ninput_args: " + op_json

View File

@ -73,7 +73,8 @@ static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph
KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
switch (kernel_type) {
case KernelType::TBE_KERNEL: {
if (AnfAlgo::GetKernelMod(anf_node) == nullptr) {
if (AnfAlgo::GetKernelMod(anf_node) == nullptr &&
AnfAlgo::GetFusionType(anf_node) == kernel::FusionType::DYNAMIC) {
tbe_nodes.push_back(anf_node);
}
break;

View File

@ -45,6 +45,7 @@ enum FusionType {
COMMREDUCE,
SEGMENT,
OPAQUE,
DYNAMIC,
UNKNOWN_FUSION_TYPE = -1,
};
enum OpPattern {

View File

@ -63,7 +63,7 @@ const std::unordered_map<std::string, size_t> type_nbyte_maps = {
const std::unordered_map<std::string, FusionType> fusion_type_maps = {
{"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
{"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
{"SEGMENT", FusionType::SEGMENT}, {"DYNAMIC", FusionType::DYNAMIC}, {"OPAQUE", FusionType::OPAQUE},
};
TypeId DtypeToTypeId(const std::string &dtypes) {

View File

@ -205,6 +205,20 @@ void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::stri
if (task_iter == pre_task_map_.end()) {
MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id;
}
auto node = task_iter->second;
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
std::string start_flag = "fusion_pattern_start";
std::string end_flag = "fusion_pattern_end";
int start = pre_build_result.find(start_flag);
int end = pre_build_result.find(end_flag);
if (start != -1 && end != -1) {
std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size());
transform(result.begin(), result.end(), result.begin(), ::toupper);
FusionType fusion_type = tbe::GetFusionType(result);
builder->SetFusionType(fusion_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
(void)pre_task_map_.erase(task_iter);
}

View File

@ -563,6 +563,7 @@ void AscendSession::InitRuntimeResource() {
}
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
device::ascend::KernelPreBuild(kernel_graph.get());
MS_LOG(INFO) << "HardwareOptimize start!";
opt::AscendBackendOptimization(kernel_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);

View File

@ -17,7 +17,7 @@
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
matmul_op_info = TBERegOp("MatMul") \
.fusion_type("OPAQUE") \
.fusion_type("DYNAMIC") \
.async_flag(False) \
.binfile_name("matmul.so") \
.compute_cost(10) \