[MSLITE] Modify the OpenCL CMake script to support cross-platform builds, such as on Windows.

This commit is contained in:
Zhu Guodong 2022-12-15 17:22:39 +08:00
parent 2a48d461b8
commit 5847b11963
2 changed files with 7 additions and 18 deletions

View File

@@ -33,23 +33,12 @@ function(gene_opencl CL_SRC_DIR)
endif() endif()
file(GLOB_RECURSE CL_LIST ${CL_SRC_DIR}/*.cl) file(GLOB_RECURSE CL_LIST ${CL_SRC_DIR}/*.cl)
foreach(file_path ${CL_LIST}) foreach(file_path ${CL_LIST})
file(REMOVE ${file_path}.inc) set(out_file_path "${file_path}.inc")
file(REMOVE ${out_file_path})
string(REGEX REPLACE ".+/(.+)\\..*" "\\1" kernel_name "${file_path}") string(REGEX REPLACE ".+/(.+)\\..*" "\\1" kernel_name "${file_path}")
set(inc_file_ex "${file_path}.inc") file(READ ${file_path} cl_program)
execute_process( string(CONCAT cl_str "static const std::string ${kernel_name}_source = R\"(\n" "${cl_program}" ")\";")
COMMAND bash -c "sed 's/\\\\/\\\\\\\\/g' " file(WRITE ${out_file_path} "${cl_str}")
COMMAND bash -c "sed 's/\\\"/\\\\\\\"/g' "
COMMAND bash -c "sed 's/$/\\\\n\\\" \\\\/' "
COMMAND bash -c "sed 's/^/\\\"/' "
WORKING_DIRECTORY ${CL_SRC_DIR}
INPUT_FILE ${file_path}
OUTPUT_FILE ${inc_file_ex}
RESULT_VARIABLE RESULT)
if(NOT RESULT EQUAL "0")
message(FATAL_ERROR "error! when generate ${inc_file_ex}")
endif()
__exec_cmd(COMMAND sed -i "1i\\static const char *${kernel_name}_source =\\\"\\\\n\\\" \\\\"
${inc_file_ex} WORKING_DIRECTORY ${CL_SRC_DIR})
__exec_cmd(COMMAND sed -i "$a\\\\\;" ${inc_file_ex} WORKING_DIRECTORY ${CL_SRC_DIR})
endforeach() endforeach()
endfunction() endfunction()

View File

@@ -43,7 +43,7 @@ int ElementOptMulAcc(const float *input0, const float input1, float *output, con
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state, const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
float *buffer[7], const LstmParameter *lstm_param); float *buffer[C6NUM], const LstmParameter *lstm_param);
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, float *buffer[7], const float *state_bias, float *hidden_state, float *cell_state, float *buffer[7],