diff --git a/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py b/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py index c5ccc13a4cd..16ddbcb5870 100644 --- a/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +++ b/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,22 +18,8 @@ import shutil import subprocess import sys from multiprocessing import Pool, cpu_count -import importlib - -def get_akg_path(): - """get akg directory base path""" - search_res = importlib.util.find_spec("mindspore") - if search_res is None: - raise RuntimeError("Cannot find mindspore module!") - - res_path = search_res.origin - find_pos = res_path.find("__init__.py") - if find_pos == -1: - raise RuntimeError("Find module mindspore origin file failed!") - akg_path = "{}_akg".format(res_path[:find_pos]) - if not os.path.isdir(akg_path): - raise RuntimeError("Cannot find akg from mindspore module!") - return akg_path +from mindspore import log as logger +from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_akg_path def copy_json(pid_path, ppid_path): """ @@ -43,7 +29,8 @@ def copy_json(pid_path, ppid_path): os.mkdir(ppid_path) json_files = os.listdir(pid_path) for json_file in json_files: - shutil.move(pid_path + '/' + json_file, ppid_path) + shutil.move(os.path.join(pid_path, json_file), ppid_path) + def _compile_akg_task_gpu(*json_strs): """ @@ -52,6 +39,7 @@ def _compile_akg_task_gpu(*json_strs): Parameters: json_strs: list. List contains multiple kernel infos, suitable for json compile api. """ + sys.path.insert(0, get_akg_path()) p = __import__("akg", globals(), locals(), ['ms'], 0) func = getattr(p.ms, "compilewithjson") @@ -66,6 +54,7 @@ def _compile_akg_task_gpu(*json_strs): copy_json(pid_path, os.path.realpath("./cuda_meta_" + str(os.getppid()))) shutil.rmtree(pid_path) + def _compile_akg_task_ascend(*json_strs): """ compile func called in single process @@ -76,11 +65,10 @@ def _compile_akg_task_ascend(*json_strs): akg_compiler = os.path.join(os.path.split( os.path.realpath(__file__))[0], "compiler.py") for json_str in json_strs: - res = subprocess.run([sys.executable, akg_compiler, json_str], text=True) - - if res.returncode != 0: - raise ValueError("Failed, args: {}!".format(json_str)) - + try: + subprocess.run([sys.executable, akg_compiler, json_str], text=True, check=True) + except BaseException as e: + logger.error(e, "Failed, args: {}!".format(json_str)) def create_akg_parallel_process(process_num, wait_time, platform=""): @@ -92,6 +80,7 @@ def create_akg_parallel_process(process_num, wait_time, platform=""): """ return AkgProcess(process_num, wait_time, platform) + class AkgProcess: """akg kernel parallel process""" diff --git a/mindspore/_extends/parallel_compile/akg_compiler/compiler.py b/mindspore/_extends/parallel_compile/akg_compiler/compiler.py index 7aae1db2a92..34f5d85ebb1 100644 --- a/mindspore/_extends/parallel_compile/akg_compiler/compiler.py +++ b/mindspore/_extends/parallel_compile/akg_compiler/compiler.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +13,8 @@ # limitations under the License. # ============================================================================ """Providing akg compile with json""" -import importlib -import os import sys -def get_akg_path(): - """get akg directory base path""" - search_res = importlib.util.find_spec("mindspore") - if search_res is None: - raise RuntimeError("Cannot find mindspore module!") - - res_path = search_res.origin - find_pos = res_path.find("__init__.py") - if find_pos == -1: - raise RuntimeError("Find module mindspore origin file failed!") - akg_path = "{}_akg".format(res_path[:find_pos]) - if not os.path.isdir(akg_path): - raise RuntimeError("Cannot find akg from mindspore module!") - return akg_path def run_compiler(op_json): """ @@ -43,6 +27,7 @@ def run_compiler(op_json): Returns: None """ + from get_file_path import get_akg_path sys.path.insert(0, get_akg_path()) p = __import__("akg", globals(), locals(), ['ms'], 0) func = getattr(p.ms, "compilewithjson") @@ -50,5 +35,6 @@ def run_compiler(op_json): if not res: raise ValueError("Compile error") + if __name__ == "__main__": run_compiler(sys.argv[1]) diff --git a/mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py b/mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py new file mode 100644 index 00000000000..98d5fdc39e4 --- /dev/null +++ b/mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py @@ -0,0 +1,33 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Providing akg directory base path""" +import importlib +import os + + +def get_akg_path(): + """get akg directory base path""" + search_res = importlib.util.find_spec("mindspore") + if search_res is None: + raise RuntimeError("Cannot find mindspore module!") + + res_path = search_res.origin + find_pos = res_path.find("__init__.py") + if find_pos == -1: + raise RuntimeError("Find module mindspore origin file failed!") + akg_path = "{}_akg".format(res_path[:find_pos]) + if not os.path.isdir(akg_path): + raise RuntimeError("Cannot find akg from mindspore module!") + return akg_path diff --git a/mindspore/_extends/remote/kernel_build_server.py b/mindspore/_extends/remote/kernel_build_server.py index 3599d95a528..965032d372c 100644 --- a/mindspore/_extends/remote/kernel_build_server.py +++ b/mindspore/_extends/remote/kernel_build_server.py @@ -16,22 +16,7 @@ import os from mindspore import log as logger from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process - - -class AkgBuilder: - """Akg building wrapper""" - - def __init__(self): - pass - - def create(self, process_num, waitime, platform=""): - self.akg_builder = create_akg_parallel_process(process_num, waitime, platform) - - def accept_json(self, json): - return self.akg_builder.accept_json(json) - - def compile(self): - return self.akg_builder.compile() +from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single class Messager: @@ -140,5 +125,65 @@ class Messager: raise NotImplementedError +class AkgBuilder(): + """Akg building wrapper""" + + def __init__(self): + pass + + def create(self, process_num, waitime, platform=""): + """ Create akg processor""" + + self.akg_processor = create_akg_parallel_process(process_num, waitime, platform) + + def accept_json(self, json): + """ Accept json""" + + return self.akg_processor.accept_json(json) + + def compile(self): + """Compile""" + + return self.akg_processor.compile() + + def handle(self, messager, arg, platform=""): + """Handle message about akg""" + + if arg == 'AKG/PID': + messager.send_res(os.getpid()) + elif arg == 'AKG/START': + messager.send_ack() + process_num_str = messager.get_message() + messager.send_ack() + wait_time_str = messager.get_message() + self.create(int(process_num_str), int(wait_time_str), platform) + messager.send_ack() + elif arg == 'AKG/DATA': + messager.send_ack() + while True: + req = messager.get_message() + if req.startswith('{'): + self.accept_json(req) + messager.send_ack() + elif req == 'AKG/WAIT': + res = self.compile() + messager.send_res(res) + break + else: + messager.send_ack(False) + break + elif arg == 'AKG/COMPILE': + messager.send_ack() + json = messager.get_message() + try: + akg_compile_single(json) + except ValueError: + messager.send_ack(False) + messager.exit() + finally: + pass + messager.send_ack() + + def get_logger(): return logger diff --git a/mindspore/_extends/remote/kernel_build_server_ascend.py b/mindspore/_extends/remote/kernel_build_server_ascend.py index e084f2acee2..a98842519f9 100644 --- a/mindspore/_extends/remote/kernel_build_server_ascend.py +++ b/mindspore/_extends/remote/kernel_build_server_ascend.py @@ -98,35 +98,6 @@ class AscendMessager(Messager): self.send_ack(False) self.exit() - def akg_handle(self, arg): - """ - Handle arg start with AKG - """ - if arg == 'AKG/START': - self.send_ack() - process_num_str = self.get_message() - self.send_ack() - wait_time_str = self.get_message() - self.akg_builder.create(int(process_num_str), int(wait_time_str), "ASCEND") - self.send_ack() - elif arg == 'AKG/DATA': - self.send_ack() - while True: - req = self.get_message() - if req.startswith('{'): - self.akg_builder.accept_json(req) - self.send_ack() - elif req == 'AKG/WAIT': - res = self.akg_builder.compile() - self.send_res(res) - break - else: - self.send_ack(False) - break - else: - self.send_ack(False) - self.exit() - def handle(self): """ Communicate with remote client. @@ -136,7 +107,7 @@ class AscendMessager(Messager): if arg.startswith('TBE'): self.tbe_handle(arg) elif arg.startswith('AKG'): - self.akg_handle(arg) + self.akg_builder.handle(self, arg, "ASCEND") elif arg == 'FORMAT': self.send_ack() json = self.get_message() diff --git a/mindspore/_extends/remote/kernel_build_server_gpu.py b/mindspore/_extends/remote/kernel_build_server_gpu.py index 80d142a3739..28c1efaa3a9 100644 --- a/mindspore/_extends/remote/kernel_build_server_gpu.py +++ b/mindspore/_extends/remote/kernel_build_server_gpu.py @@ -13,11 +13,9 @@ # limitations under the License. # ============================================================================ """kernel build server for gpu""" -import os import sys import warnings from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder -from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single class GpuMessager(Messager): @@ -37,40 +35,8 @@ class GpuMessager(Messager): Reference protocol between them at PR#4063 """ arg = self.get_message() - if arg == 'AKG/PID': - self.send_res(os.getpid()) - elif arg == 'AKG/START': - self.send_ack() - process_num_str = self.get_message() - self.send_ack() - wait_time_str = self.get_message() - self.akg_builder.create(int(process_num_str), int(wait_time_str), "GPU") - self.send_ack() - elif arg == 'AKG/DATA': - self.send_ack() - while True: - req = self.get_message() - if req.startswith('{'): - self.akg_builder.accept_json(req) - self.send_ack() - elif req == 'AKG/WAIT': - res = self.akg_builder.compile() - self.send_res(res) - break - else: - self.send_ack(False) - break - elif arg == 'AKG/COMPILE': - self.send_ack() - json = self.get_message() - try: - akg_compile_single(json) - except ValueError: - self.send_ack(False) - self.exit() - finally: - pass - self.send_ack() + if "AKG" in arg: + self.akg_builder.handle(self, arg, "GPU") else: self.send_ack(False) self.exit()