forked from mindspore-Ecosystem/mindspore
!17288 fix staic warning in master which fixed in r1.2
Merge pull request !17288 from peiwenfang/fix_static_check_master
This commit is contained in:
commit
7b11a5e451
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,22 +18,8 @@ import shutil
|
|||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool, cpu_count
|
||||
import importlib
|
||||
|
||||
def get_akg_path():
|
||||
"""get akg directory base path"""
|
||||
search_res = importlib.util.find_spec("mindspore")
|
||||
if search_res is None:
|
||||
raise RuntimeError("Cannot find mindspore module!")
|
||||
|
||||
res_path = search_res.origin
|
||||
find_pos = res_path.find("__init__.py")
|
||||
if find_pos == -1:
|
||||
raise RuntimeError("Find module mindspore origin file failed!")
|
||||
akg_path = "{}_akg".format(res_path[:find_pos])
|
||||
if not os.path.isdir(akg_path):
|
||||
raise RuntimeError("Cannot find akg from mindspore module!")
|
||||
return akg_path
|
||||
from mindspore import log as logger
|
||||
from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_akg_path
|
||||
|
||||
def copy_json(pid_path, ppid_path):
|
||||
"""
|
||||
|
@ -43,7 +29,8 @@ def copy_json(pid_path, ppid_path):
|
|||
os.mkdir(ppid_path)
|
||||
json_files = os.listdir(pid_path)
|
||||
for json_file in json_files:
|
||||
shutil.move(pid_path + '/' + json_file, ppid_path)
|
||||
shutil.move(os.path.join(pid_path, json_file), ppid_path)
|
||||
|
||||
|
||||
def _compile_akg_task_gpu(*json_strs):
|
||||
"""
|
||||
|
@ -52,6 +39,7 @@ def _compile_akg_task_gpu(*json_strs):
|
|||
Parameters:
|
||||
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
|
||||
"""
|
||||
|
||||
sys.path.insert(0, get_akg_path())
|
||||
p = __import__("akg", globals(), locals(), ['ms'], 0)
|
||||
func = getattr(p.ms, "compilewithjson")
|
||||
|
@ -66,6 +54,7 @@ def _compile_akg_task_gpu(*json_strs):
|
|||
copy_json(pid_path, os.path.realpath("./cuda_meta_" + str(os.getppid())))
|
||||
shutil.rmtree(pid_path)
|
||||
|
||||
|
||||
def _compile_akg_task_ascend(*json_strs):
|
||||
"""
|
||||
compile func called in single process
|
||||
|
@ -76,11 +65,10 @@ def _compile_akg_task_ascend(*json_strs):
|
|||
akg_compiler = os.path.join(os.path.split(
|
||||
os.path.realpath(__file__))[0], "compiler.py")
|
||||
for json_str in json_strs:
|
||||
res = subprocess.run([sys.executable, akg_compiler, json_str], text=True)
|
||||
|
||||
if res.returncode != 0:
|
||||
raise ValueError("Failed, args: {}!".format(json_str))
|
||||
|
||||
try:
|
||||
subprocess.run([sys.executable, akg_compiler, json_str], text=True, check=True)
|
||||
except BaseException as e:
|
||||
logger.error(e, "Failed, args: {}!".format(json_str))
|
||||
|
||||
|
||||
def create_akg_parallel_process(process_num, wait_time, platform=""):
|
||||
|
@ -92,6 +80,7 @@ def create_akg_parallel_process(process_num, wait_time, platform=""):
|
|||
"""
|
||||
return AkgProcess(process_num, wait_time, platform)
|
||||
|
||||
|
||||
class AkgProcess:
|
||||
"""akg kernel parallel process"""
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,24 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing akg compile with json"""
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
def get_akg_path():
|
||||
"""get akg directory base path"""
|
||||
search_res = importlib.util.find_spec("mindspore")
|
||||
if search_res is None:
|
||||
raise RuntimeError("Cannot find mindspore module!")
|
||||
|
||||
res_path = search_res.origin
|
||||
find_pos = res_path.find("__init__.py")
|
||||
if find_pos == -1:
|
||||
raise RuntimeError("Find module mindspore origin file failed!")
|
||||
akg_path = "{}_akg".format(res_path[:find_pos])
|
||||
if not os.path.isdir(akg_path):
|
||||
raise RuntimeError("Cannot find akg from mindspore module!")
|
||||
return akg_path
|
||||
|
||||
def run_compiler(op_json):
|
||||
"""
|
||||
|
@ -43,6 +27,7 @@ def run_compiler(op_json):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
from get_file_path import get_akg_path
|
||||
sys.path.insert(0, get_akg_path())
|
||||
p = __import__("akg", globals(), locals(), ['ms'], 0)
|
||||
func = getattr(p.ms, "compilewithjson")
|
||||
|
@ -50,5 +35,6 @@ def run_compiler(op_json):
|
|||
if not res:
|
||||
raise ValueError("Compile error")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_compiler(sys.argv[1])
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing akg directory base path"""
|
||||
import importlib
|
||||
import os
|
||||
|
||||
|
||||
def get_akg_path():
|
||||
"""get akg directory base path"""
|
||||
search_res = importlib.util.find_spec("mindspore")
|
||||
if search_res is None:
|
||||
raise RuntimeError("Cannot find mindspore module!")
|
||||
|
||||
res_path = search_res.origin
|
||||
find_pos = res_path.find("__init__.py")
|
||||
if find_pos == -1:
|
||||
raise RuntimeError("Find module mindspore origin file failed!")
|
||||
akg_path = "{}_akg".format(res_path[:find_pos])
|
||||
if not os.path.isdir(akg_path):
|
||||
raise RuntimeError("Cannot find akg from mindspore module!")
|
||||
return akg_path
|
|
@ -16,22 +16,7 @@
|
|||
import os
|
||||
from mindspore import log as logger
|
||||
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
|
||||
|
||||
|
||||
class AkgBuilder:
|
||||
"""Akg building wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def create(self, process_num, waitime, platform=""):
|
||||
self.akg_builder = create_akg_parallel_process(process_num, waitime, platform)
|
||||
|
||||
def accept_json(self, json):
|
||||
return self.akg_builder.accept_json(json)
|
||||
|
||||
def compile(self):
|
||||
return self.akg_builder.compile()
|
||||
from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single
|
||||
|
||||
|
||||
class Messager:
|
||||
|
@ -140,5 +125,65 @@ class Messager:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class AkgBuilder():
|
||||
"""Akg building wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def create(self, process_num, waitime, platform=""):
|
||||
""" Create akg processor"""
|
||||
|
||||
self.akg_processor = create_akg_parallel_process(process_num, waitime, platform)
|
||||
|
||||
def accept_json(self, json):
|
||||
""" Accept json"""
|
||||
|
||||
return self.akg_processor.accept_json(json)
|
||||
|
||||
def compile(self):
|
||||
"""Compile"""
|
||||
|
||||
return self.akg_processor.compile()
|
||||
|
||||
def handle(self, messager, arg, platform=""):
|
||||
"""Handle message about akg"""
|
||||
|
||||
if arg == 'AKG/PID':
|
||||
messager.send_res(os.getpid())
|
||||
elif arg == 'AKG/START':
|
||||
messager.send_ack()
|
||||
process_num_str = messager.get_message()
|
||||
messager.send_ack()
|
||||
wait_time_str = messager.get_message()
|
||||
self.create(int(process_num_str), int(wait_time_str), platform)
|
||||
messager.send_ack()
|
||||
elif arg == 'AKG/DATA':
|
||||
messager.send_ack()
|
||||
while True:
|
||||
req = messager.get_message()
|
||||
if req.startswith('{'):
|
||||
self.accept_json(req)
|
||||
messager.send_ack()
|
||||
elif req == 'AKG/WAIT':
|
||||
res = self.compile()
|
||||
messager.send_res(res)
|
||||
break
|
||||
else:
|
||||
messager.send_ack(False)
|
||||
break
|
||||
elif arg == 'AKG/COMPILE':
|
||||
messager.send_ack()
|
||||
json = messager.get_message()
|
||||
try:
|
||||
akg_compile_single(json)
|
||||
except ValueError:
|
||||
messager.send_ack(False)
|
||||
messager.exit()
|
||||
finally:
|
||||
pass
|
||||
messager.send_ack()
|
||||
|
||||
|
||||
def get_logger():
|
||||
return logger
|
||||
|
|
|
@ -98,35 +98,6 @@ class AscendMessager(Messager):
|
|||
self.send_ack(False)
|
||||
self.exit()
|
||||
|
||||
def akg_handle(self, arg):
|
||||
"""
|
||||
Handle arg start with AKG
|
||||
"""
|
||||
if arg == 'AKG/START':
|
||||
self.send_ack()
|
||||
process_num_str = self.get_message()
|
||||
self.send_ack()
|
||||
wait_time_str = self.get_message()
|
||||
self.akg_builder.create(int(process_num_str), int(wait_time_str), "ASCEND")
|
||||
self.send_ack()
|
||||
elif arg == 'AKG/DATA':
|
||||
self.send_ack()
|
||||
while True:
|
||||
req = self.get_message()
|
||||
if req.startswith('{'):
|
||||
self.akg_builder.accept_json(req)
|
||||
self.send_ack()
|
||||
elif req == 'AKG/WAIT':
|
||||
res = self.akg_builder.compile()
|
||||
self.send_res(res)
|
||||
break
|
||||
else:
|
||||
self.send_ack(False)
|
||||
break
|
||||
else:
|
||||
self.send_ack(False)
|
||||
self.exit()
|
||||
|
||||
def handle(self):
|
||||
"""
|
||||
Communicate with remote client.
|
||||
|
@ -136,7 +107,7 @@ class AscendMessager(Messager):
|
|||
if arg.startswith('TBE'):
|
||||
self.tbe_handle(arg)
|
||||
elif arg.startswith('AKG'):
|
||||
self.akg_handle(arg)
|
||||
self.akg_builder.handle(self, arg, "ASCEND")
|
||||
elif arg == 'FORMAT':
|
||||
self.send_ack()
|
||||
json = self.get_message()
|
||||
|
|
|
@ -13,11 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""kernel build server for gpu"""
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder
|
||||
from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single
|
||||
|
||||
|
||||
class GpuMessager(Messager):
|
||||
|
@ -37,40 +35,8 @@ class GpuMessager(Messager):
|
|||
Reference protocol between them at PR#4063
|
||||
"""
|
||||
arg = self.get_message()
|
||||
if arg == 'AKG/PID':
|
||||
self.send_res(os.getpid())
|
||||
elif arg == 'AKG/START':
|
||||
self.send_ack()
|
||||
process_num_str = self.get_message()
|
||||
self.send_ack()
|
||||
wait_time_str = self.get_message()
|
||||
self.akg_builder.create(int(process_num_str), int(wait_time_str), "GPU")
|
||||
self.send_ack()
|
||||
elif arg == 'AKG/DATA':
|
||||
self.send_ack()
|
||||
while True:
|
||||
req = self.get_message()
|
||||
if req.startswith('{'):
|
||||
self.akg_builder.accept_json(req)
|
||||
self.send_ack()
|
||||
elif req == 'AKG/WAIT':
|
||||
res = self.akg_builder.compile()
|
||||
self.send_res(res)
|
||||
break
|
||||
else:
|
||||
self.send_ack(False)
|
||||
break
|
||||
elif arg == 'AKG/COMPILE':
|
||||
self.send_ack()
|
||||
json = self.get_message()
|
||||
try:
|
||||
akg_compile_single(json)
|
||||
except ValueError:
|
||||
self.send_ack(False)
|
||||
self.exit()
|
||||
finally:
|
||||
pass
|
||||
self.send_ack()
|
||||
if "AKG" in arg:
|
||||
self.akg_builder.handle(self, arg, "GPU")
|
||||
else:
|
||||
self.send_ack(False)
|
||||
self.exit()
|
||||
|
|
Loading…
Reference in New Issue