set soc version

This commit is contained in:
jjfeing 2020-09-28 15:09:47 +08:00
parent 1c06d6e024
commit 7dda95d247
5 changed files with 41 additions and 23 deletions

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""tbe common"""
import json
import os
class TBEException(Exception):
@ -27,23 +26,6 @@ class TBEException(Exception):
return self.__error_msg
def get_ddk_version():
"""get ddk version"""
ddk_version = os.environ.get("DDK_VERSION")
if ddk_version is None:
default_ddk_info_file = '/usr/local/HiAI/runtime/ddk_info'
backup_ddk_info_file = '/usr/local/Ascend/fwkacllib/ddk_info'
if os.path.exists(default_ddk_info_file):
with open(default_ddk_info_file, "r") as fp:
ddk_version = json.load(fp)["VERSION"]
elif os.path.exists(backup_ddk_info_file):
with open(backup_ddk_info_file, "r") as fp:
ddk_version = json.load(fp)["VERSION"]
else:
ddk_version = "Ascend910"
return ddk_version
def get_build_in_impl_path():
"""get build-in tbe implement path"""
tbe_impl_path = os.environ.get("TBE_IMPL_PATH")

View File

@ -18,9 +18,8 @@ import os
import sys
from te.platform.cce_conf import te_set_version
from te.platform.fusion_util import fusion_op
from common import check_kernel_info, get_args, get_build_in_impl_path, get_ddk_version
from common import check_kernel_info, get_args, get_build_in_impl_path
ddk_version = get_ddk_version()
build_in_impl_path = get_build_in_impl_path()
# op function list
@ -30,7 +29,6 @@ fusion_pattern_end_flag = "fusion_pattern_end"
def _initialize(impl_path):
"""Initialize"""
te_set_version(ddk_version)
if impl_path == "":
op_module_name = build_in_impl_path
else:
@ -53,7 +51,7 @@ def build_op(build_type, json_str):
"""
kernel_info = json.loads(json_str)
check_kernel_info(kernel_info)
te_set_version(kernel_info["op_info"]["socVersion"])
op_name = kernel_info['op_info']['name']
try:
@ -111,7 +109,7 @@ def compile_fusion_op(json_str):
Exception: If specific keyword is not found.
"""
args = json.loads(json_str)
te_set_version(ddk_version)
te_set_version(args['fusion_op']["socVersion"])
if 'fusion_op' not in args or not args['fusion_op']:
raise ValueError("Json string Errors, key:fusion_op not found.")
fusion_op_arg = args['fusion_op']

View File

@ -25,6 +25,7 @@
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "utils/ms_context.h"
#include "runtime/dev.h"
namespace mindspore {
namespace kernel {
@ -86,6 +87,8 @@ constexpr auto kJPyModulePath = "py_module_path";
constexpr auto kJPreBuildOutsAttrs = "prebuild_outs_attrs";
constexpr auto kJKwdArgs = "kwds_args";
constexpr auto kJListArgs = "list_args";
constexpr auto kJSocVersion = "socVersion";
constexpr auto kSOC_VERSION = "SOC_VERSION";
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
nlohmann::json *kernel_json) {
@ -122,6 +125,8 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspor
nlohmann::json attrs_json;
(void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json);
op_info_json[kJAttrs] = attrs_json;
auto soc_version = TbeKernelJsonCreator::GetSocVersion();
op_info_json[kJSocVersion] = soc_version;
std::string json_str = op_info_json.dump();
size_t hash_id = std::hash<std::string>()(json_str);
auto context_ptr = MsContext::GetInstance();
@ -414,6 +419,30 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_no
return true;
}
string TbeKernelJsonCreator::GetSocVersion() {
// Get default soc version.
const int kSocVersionLen = 50;
char soc_version[kSocVersionLen] = {0};
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "GetSocVersion failed.";
}
MS_LOG(INFO) << "Default SocVersion is " << soc_version;
// Get soc version from env value.
const char *soc_version_env = getenv(kSOC_VERSION);
if (soc_version_env != nullptr) {
if (std::strcmp(soc_version, soc_version_env) != 0) {
MS_LOG(WARNING) << "SocVerison change to " << soc_version_env;
ret = rtSetSocVersion(soc_version_env);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "SetSocVersion to " << soc_version_env << " failed, errorno: " << ret;
}
return soc_version_env;
}
}
return soc_version;
}
void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value,
nlohmann::json *attr_obj) {
MS_EXCEPTION_IF_NULL(value);
@ -630,6 +659,8 @@ bool TbeKernelBuild::GenFusionScopeJson(const std::vector<mindspore::AnfNodePtr>
index = 0;
data_list.insert(data_list.end(), compute_list.begin(), compute_list.end());
(*fusion_json)[kFusionOpList] = data_list;
auto soc_version = TbeKernelJsonCreator::GetSocVersion();
(*fusion_json)[kJSocVersion] = soc_version;
return true;
}
@ -859,6 +890,7 @@ bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::Anf
(*data_str)[kJName] = name;
nlohmann::json output_desc;
output_desc[kJName] = name;
output_desc[kJDataType] = 0;
output_desc[kJShape] = "NULL";
output_desc_list.push_back(output_desc);
(*index)++;
@ -991,6 +1023,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
for (size_t i = 0; i < optional_num; ++i) {
nlohmann::json optional_input_desc;
optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index);
optional_input_desc[kJShape] = "NULL";
(*index)++;
(*layer_iter)->emplace_back(nullptr);
input_desc_list_tmp.emplace_back(optional_input_desc);

View File

@ -92,6 +92,7 @@ class TbeKernelJsonCreator {
std::string json_name() { return json_name_; }
bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *attrs_json);
static string GetSocVersion();
private:
bool GenTbeInputsJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,

View File

@ -53,6 +53,10 @@ rtError_t rtFunctionRegister(void *binHandle, const void *stubFunc, const char *
return RT_ERROR_NONE;
}
RTS_API rtError_t rtSetSocVersion(const char *version) { return RT_ERROR_NONE; }
rtError_t rtGetSocVersion(char *version, const uint32_t maxLen) { return RT_ERROR_NONE; }
rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, rtSmDesc_t *smDesc,
rtStream_t stream) {
return RT_ERROR_NONE;