From be688a92cd4f7efa9ee9fa727db24d5e0a7ad154 Mon Sep 17 00:00:00 2001
From: Emir Haleva <emir.haleva@huawei.com>
Date: Thu, 3 Jun 2021 20:15:41 +0300
Subject: [PATCH] merge compilation of MS Lite Inference and Train

---
 build.sh                                      | 12 ++--
 cmake/package_lite.cmake                      | 68 ++++++++++++-------
 .../cpu/nnacl/infer/tile_infer.c              | 15 ----
 mindspore/lite/CMakeLists.txt                 |  3 +-
 mindspore/lite/examples/train_lenet/Makefile  |  2 +-
 .../lite/examples/transfer_learning/Makefile  |  2 +-
 mindspore/lite/java/native/CMakeLists.txt     |  1 -
 mindspore/lite/minddata/CMakeLists.txt        |  1 +
 mindspore/lite/src/CMakeLists.txt             | 67 +++++++++---------
 .../kernel/arm/fp32/matmul_fp32_base.cc       | 14 +---
 mindspore/lite/test/st/run_net_train.sh       | 19 +++---
 .../arm/fp32/fullconnection_fp32_tests.cc     | 19 +-----
 .../lite/tools/benchmark_train/CMakeLists.txt | 13 ++--
 13 files changed, 106 insertions(+), 130 deletions(-)

diff --git a/build.sh b/build.sh
index 6b6eddf05da..3cc04904854 100755
--- a/build.sh
+++ b/build.sh
@@ -506,11 +506,11 @@ build_lite_x86_64_jni_and_jar()
     tar -zxf ${BASEPATH}/output/tmp/${pkg_name}.tar.gz
     rm -rf ${LITE_JAVA_PATH}/java/linux_x86/libs/   && mkdir -pv ${LITE_JAVA_PATH}/java/linux_x86/libs/
     rm -rf ${LITE_JAVA_PATH}/native/libs/linux_x86/ && mkdir -pv ${LITE_JAVA_PATH}/native/libs/linux_x86/
-    cp ./${pkg_name}/${inference_or_train}/lib/*.so* ${LITE_JAVA_PATH}/java/linux_x86/libs/
-    cp ./${pkg_name}/${inference_or_train}/lib/*.so* ${LITE_JAVA_PATH}/native/libs/linux_x86/
+    cp ./${pkg_name}/inference/lib/*.so* ${LITE_JAVA_PATH}/java/linux_x86/libs/
+    cp ./${pkg_name}/inference/lib/*.so* ${LITE_JAVA_PATH}/native/libs/linux_x86/
     if [ -f "mindspore-lite-${VERSION_STR}-train-linux-x64.tar.gz" ]; then
-        cp ./${pkg_name}/train/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/java/linux_x86/libs/
-        cp ./${pkg_name}/train/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/native/libs/linux_x86/
+        cp ./${pkg_name}/inference/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/java/linux_x86/libs/
+        cp ./${pkg_name}/inference/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/native/libs/linux_x86/
     fi
     # build jni so
     cd ${BASEPATH}/mindspore/lite/build
@@ -525,7 +525,7 @@ build_lite_x86_64_jni_and_jar()
     fi
     cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
     cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
-    cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/${inference_or_train}/lib/
+    cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
 
     # build java common
     cd ${LITE_JAVA_PATH}/java/common
@@ -537,7 +537,7 @@ build_lite_x86_64_jni_and_jar()
     cd ${LITE_JAVA_PATH}/java/linux_x86/
     gradle clean
     gradle releaseJar
-    cp ./build/lib/jar/*.jar ${BASEPATH}/output/tmp/${pkg_name}/${inference_or_train}/lib/
+    cp ./build/lib/jar/*.jar ${BASEPATH}/output/tmp/${pkg_name}/inference/lib/
 
     # package
     cd ${BASEPATH}/output/tmp
diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake
index 2ad6cd4f9e0..439d55ed810 100644
--- a/cmake/package_lite.cmake
+++ b/cmake/package_lite.cmake
@@ -7,27 +7,19 @@ set(CONVERTER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/converter)
 set(OBFUSCATOR_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/obfuscator)
 set(CROPPER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/cropper)
 
-if(SUPPORT_TRAIN)
-    set(RUNTIME_DIR ${RUNTIME_PKG_NAME}/train)
-    set(RUNTIME_INC_DIR ${RUNTIME_PKG_NAME}/train/include)
-    set(RUNTIME_LIB_DIR ${RUNTIME_PKG_NAME}/train/lib)
-    set(MIND_DATA_INC_DIR ${RUNTIME_PKG_NAME}/train/include/dataset)
-    set(TURBO_DIR ${RUNTIME_PKG_NAME}/train/third_party/libjpeg-turbo)
-    set(SECUREC_DIR ${RUNTIME_PKG_NAME}/train/third_party/securec)
-    set(MINDSPORE_LITE_LIB_NAME libmindspore-lite-train)
-    set(BENCHMARK_NAME benchmark_train)
-    set(BENCHMARK_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/benchmark_train)
-else()
-    set(RUNTIME_DIR ${RUNTIME_PKG_NAME}/inference)
-    set(RUNTIME_INC_DIR ${RUNTIME_PKG_NAME}/inference/include)
-    set(RUNTIME_LIB_DIR ${RUNTIME_PKG_NAME}/inference/lib)
-    set(MIND_DATA_INC_DIR ${RUNTIME_PKG_NAME}/inference/include/dataset)
-    set(TURBO_DIR ${RUNTIME_PKG_NAME}/inference/third_party/libjpeg-turbo)
-    set(SECUREC_DIR ${RUNTIME_PKG_NAME}/inference/third_party/securec)
-    set(MINDSPORE_LITE_LIB_NAME libmindspore-lite)
-    set(BENCHMARK_NAME benchmark)
-    set(BENCHMARK_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/benchmark)
-endif()
+set(RUNTIME_DIR ${RUNTIME_PKG_NAME}/inference)
+set(RUNTIME_INC_DIR ${RUNTIME_PKG_NAME}/inference/include)
+set(RUNTIME_LIB_DIR ${RUNTIME_PKG_NAME}/inference/lib)
+set(MIND_DATA_INC_DIR ${RUNTIME_PKG_NAME}/inference/include/dataset)
+set(TURBO_DIR ${RUNTIME_PKG_NAME}/inference/third_party/libjpeg-turbo)
+set(SECUREC_DIR ${RUNTIME_PKG_NAME}/inference/third_party/securec)
+set(MINDSPORE_LITE_LIB_NAME libmindspore-lite)
+set(BENCHMARK_NAME benchmark)
+set(BENCHMARK_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/benchmark)
+
+set(MINDSPORE_LITE_TRAIN_LIB_NAME libmindspore-lite-train)
+set(BENCHMARK_TRAIN_NAME benchmark_train)
+set(BENCHMARK_TRAIN_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/benchmark_train)
 
 if(BUILD_MINDDATA STREQUAL "full")
     install(FILES
@@ -45,7 +37,7 @@ if(BUILD_MINDDATA STREQUAL "full")
         file(GLOB JPEGTURBO_LIB_LIST ${jpeg_turbo_LIBPATH}/*.so)
         install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION
                 ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
-        install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite_static.a DESTINATION
+        install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION
                 ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a
@@ -54,7 +46,7 @@ if(BUILD_MINDDATA STREQUAL "full")
         file(GLOB JPEGTURBO_LIB_LIST ${jpeg_turbo_LIBPATH}/*.so)
         install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION
                 ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
-        install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite_static.a DESTINATION
+        install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION
                 ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a
@@ -62,7 +54,7 @@ if(BUILD_MINDDATA STREQUAL "full")
     else()
         install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION
                 ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
-        install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite_static.a DESTINATION
+        install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION
                 ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.3.0 DESTINATION ${TURBO_DIR}/lib
                 RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME})
@@ -162,6 +154,10 @@ if(PLATFORM_ARM64)
     if(SUPPORT_TRAIN)
         install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
                 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
+        install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
+                ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
+                ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
     else()
         install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
                 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
@@ -194,6 +190,10 @@ if(PLATFORM_ARM64)
     install(TARGETS wrapper ARCHIVE DESTINATION ${CODEGEN_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
     if(MSLITE_ENABLE_TOOLS)
         install(TARGETS ${BENCHMARK_NAME} RUNTIME DESTINATION ${BENCHMARK_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
+        if(SUPPORT_TRAIN)
+            install(TARGETS ${BENCHMARK_TRAIN_NAME} RUNTIME DESTINATION ${BENCHMARK_TRAIN_ROOT_DIR} COMPONENT
+                    ${RUNTIME_COMPONENT_NAME})
+        endif()
     endif()
 elseif(PLATFORM_ARM32)
     if(SUPPORT_NPU)
@@ -207,6 +207,10 @@ elseif(PLATFORM_ARM32)
     if(SUPPORT_TRAIN)
         install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
                 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
+        install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
+                ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
+                ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
     else()
         install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
                 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
@@ -239,6 +243,10 @@ elseif(PLATFORM_ARM32)
     install(TARGETS wrapper ARCHIVE DESTINATION ${CODEGEN_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
     if(MSLITE_ENABLE_TOOLS)
         install(TARGETS ${BENCHMARK_NAME} RUNTIME DESTINATION ${BENCHMARK_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
+        if(SUPPORT_TRAIN)
+            install(TARGETS ${BENCHMARK_TRAIN_NAME} RUNTIME DESTINATION ${BENCHMARK_TRAIN_ROOT_DIR} COMPONENT
+                    ${RUNTIME_COMPONENT_NAME})
+        endif()
     endif()
 elseif(WIN32)
     get_filename_component(CXX_DIR ${CMAKE_CXX_COMPILER} PATH)
@@ -279,6 +287,10 @@ elseif(WIN32)
     endif()
     if(MSLITE_ENABLE_TOOLS)
         install(TARGETS ${BENCHMARK_NAME} RUNTIME DESTINATION ${BENCHMARK_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
+        if(SUPPORT_TRAIN)
+            install(TARGETS ${BENCHMARK_TRAIN_NAME} RUNTIME DESTINATION ${BENCHMARK_TRAIN_ROOT_DIR} COMPONENT
+                    ${RUNTIME_COMPONENT_NAME})
+        endif()
     endif()
     install(FILES ${LIB_LIST} DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
     install(DIRECTORY ${flatbuffers_INC} DESTINATION ${RUNTIME_INC_DIR}/third_party/
@@ -305,6 +317,10 @@ else()
     if(SUPPORT_TRAIN)
         install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
                 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
+        install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
+                ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
+                ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
     else()
         install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
                 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
@@ -360,6 +376,10 @@ else()
     endif()
     if(MSLITE_ENABLE_TOOLS)
         install(TARGETS ${BENCHMARK_NAME} RUNTIME DESTINATION ${BENCHMARK_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
+        if(SUPPORT_TRAIN)
+            install(TARGETS ${BENCHMARK_TRAIN_NAME} RUNTIME DESTINATION ${BENCHMARK_TRAIN_ROOT_DIR} COMPONENT
+                    ${RUNTIME_COMPONENT_NAME})
+        endif()
         install(TARGETS cropper RUNTIME DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${TOP_DIR}/mindspore/lite/build/tools/cropper/cropper_mapping_cpu.cfg
                 DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tile_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tile_infer.c
index df36be303a6..b287a649507 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tile_infer.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tile_infer.c
@@ -74,20 +74,6 @@ int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
     param->multiples_[i] = input1_data[i];
   }
 
-#ifdef SUPPORT_TRAIN
-  const size_t in_dims = input->shape_size_;
-  const size_t delta_dims = in_dims - multiples_size;
-
-  size_t i = 0;
-  for (; i < delta_dims; ++i) {
-    int tmp = input->shape_[i];
-    ShapePush(out_shape, &out_shape_size, tmp);
-  }
-  for (; i < in_dims; ++i) {
-    int tmp = input->shape_[i] * (param->multiples_[i - delta_dims]);
-    ShapePush(out_shape, &out_shape_size, tmp);
-  }
-#else
   int *dims = param->dims_;
   size_t dims_size = param->dims_size_;
   if (dims_size == 0) {
@@ -110,7 +96,6 @@ int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
   }
   // change caffe param format to tflite
   TileParamCaffe2Tflite(param, out_shape_size);
-#endif
   SetShapeArray(output, out_shape, out_shape_size);
   return NNACL_OK;
 }
diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt
index 19d4678d17b..8185ee603a0 100644
--- a/mindspore/lite/CMakeLists.txt
+++ b/mindspore/lite/CMakeLists.txt
@@ -336,9 +336,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/micro/coder)
 if(NOT APPLE AND MSLITE_ENABLE_TOOLS)
     if(SUPPORT_TRAIN)
         add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark_train)
-    else()
-        add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark)
     endif()
+    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark)
 endif()
 if(NOT WIN32)
     if(MSLITE_ENABLE_TOOLS)
diff --git a/mindspore/lite/examples/train_lenet/Makefile b/mindspore/lite/examples/train_lenet/Makefile
index a120fac3178..869caa0e14d 100644
--- a/mindspore/lite/examples/train_lenet/Makefile
+++ b/mindspore/lite/examples/train_lenet/Makefile
@@ -1,6 +1,6 @@
 BASE_DIR=$(realpath ../../../../)
 APP:=bin/net_runner
-LMSLIB:=-lmindspore-lite-train
+LMSLIB:=-lmindspore-lite-train -lmindspore-lite
 LMDLIB:=-lminddata-lite
 MSDIR:=$(realpath package-$(TARGET)/lib)
 ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")
diff --git a/mindspore/lite/examples/transfer_learning/Makefile b/mindspore/lite/examples/transfer_learning/Makefile
index f22d0369638..68b5c0ce4a1 100644
--- a/mindspore/lite/examples/transfer_learning/Makefile
+++ b/mindspore/lite/examples/transfer_learning/Makefile
@@ -1,6 +1,6 @@
 BASE_DIR=$(realpath ../../../../)
 APP:=bin/net_runner
-LMSLIB:=-lmindspore-lite-train
+LMSLIB:=-lmindspore-lite-train -lmindspore-lite
 LMDLIB:=-lminddata-lite
 MSDIR:=$(realpath package-$(TARGET)/lib)
 ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")
diff --git a/mindspore/lite/java/native/CMakeLists.txt b/mindspore/lite/java/native/CMakeLists.txt
index 9ec759013fe..c2399e40d6c 100644
--- a/mindspore/lite/java/native/CMakeLists.txt
+++ b/mindspore/lite/java/native/CMakeLists.txt
@@ -60,7 +60,6 @@ set(JNI_SRC
 set(LITE_SO_NAME mindspore-lite)
 
 if(SUPPORT_TRAIN)
-  set(LITE_SO_NAME mindspore-lite-train)
   set(JNI_SRC
       ${JNI_SRC}
       ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt
index d56ea46e662..9efbb103281 100644
--- a/mindspore/lite/minddata/CMakeLists.txt
+++ b/mindspore/lite/minddata/CMakeLists.txt
@@ -297,6 +297,7 @@ if(BUILD_MINDDATA STREQUAL "full")
             ${CORE_DIR}/utils/ms_utils.cc
             ${MINDDATA_FULL_SRC}
             )
+    set_target_properties(minddata-lite_static PROPERTIES OUTPUT_NAME "minddata-lite")
 
     add_dependencies(minddata-lite fbs_src)
     add_dependencies(minddata-lite_static fbs_src)
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index d31be860170..15a2aaec6e9 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -110,33 +110,25 @@ if(MSLITE_GPU_BACKEND STREQUAL cuda)
             ${CUDA_RUNTIME_SRC}
             )
 endif()
-if(SUPPORT_TRAIN)
-    set(ANF_SRC
-            ${ANF_SRC}
+set(TRAIN_SRC
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/train_utils.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/lr_scheduler.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_metrics.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_monitor.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
+        )
+if(ENABLE_V0)
+    set(TRAIN_SRC
+            ${TRAIN_SRC}
+            ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter_v0.cc
             )
-    set(PASS_SRC)
-    set(LITE_SRC
-            ${LITE_SRC}
-            ${ANF_SRC}
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/train_utils.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/lr_scheduler.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_metrics.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_monitor.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
-            ${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
-            )
-    if(ENABLE_V0)
-      set(LITE_SRC
-              ${LITE_SRC}
-              ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter_v0.cc
-              )
-    endif()
 endif()
 
 if(ENABLE_MINDRT)
@@ -193,11 +185,19 @@ if(BUILD_MINDDATA STREQUAL "lite")
     target_link_libraries(mindspore-lite minddata_eager_mid minddata-lite)
     target_link_libraries(mindspore-lite_static minddata_eager_mid)
 endif()
+
 if(SUPPORT_TRAIN)
-    add_dependencies(mindspore-lite fbs_inner_src)
-    add_dependencies(mindspore-lite_static fbs_inner_src)
-    target_link_libraries(mindspore-lite minddata-lite)
-    target_link_libraries(mindspore-lite_static minddata-lite)
+  add_library(mindspore-lite-train SHARED ${TRAIN_SRC})
+  set_target_properties(mindspore-lite-train PROPERTIES OUTPUT_NAME "mindspore-lite-train")
+  add_dependencies(mindspore-lite-train fbs_src fbs_inner_src)
+  set_target_properties(mindspore-lite-train PROPERTIES CLEAN_DIRECT_OUTPUT 1)
+  target_link_libraries(mindspore-lite-train minddata-lite mindspore-lite)
+
+  add_library(mindspore-lite-train_static STATIC ${TRAIN_SRC})
+  set_target_properties(mindspore-lite-train_static PROPERTIES OUTPUT_NAME "mindspore-lite-train")
+  add_dependencies(mindspore-lite-train_static fbs_inner_src)
+  set_target_properties(mindspore-lite-train_static PROPERTIES CLEAN_DIRECT_OUTPUT 1)
+  target_link_libraries(mindspore-lite-train_static minddata-lite mindspore-lite)
 endif()
 
 if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
@@ -279,12 +279,7 @@ if(NOT WIN32)
     target_link_libraries(mindspore-lite dl)
 endif()
 
-if(SUPPORT_TRAIN)
-    set_target_properties(mindspore-lite PROPERTIES OUTPUT_NAME "mindspore-lite-train")
-    set_target_properties(mindspore-lite_static PROPERTIES OUTPUT_NAME "mindspore-lite-train")
-endif()
-
 if(ENABLE_MODEL_OBF)
     target_link_libraries(mindspore-lite ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
     target_link_libraries(mindspore-lite_static ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
-endif()
\ No newline at end of file
+endif()
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
index 84647e9f9bd..fa5523774da 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
@@ -72,15 +72,11 @@ int MatmulFp32BaseCPUKernel::InitBufferA() {
   if (a_pack_ptr_ != nullptr) {
     return RET_OK;
   }
-#ifdef SUPPORT_TRAIN
   if (op_parameter_->is_train_session_) {
     a_pack_ptr_ = reinterpret_cast<float *>(workspace());
   } else {
     a_pack_ptr_ = reinterpret_cast<float *>(context_->allocator->Malloc(matrix_a_pack_size_ * sizeof(float)));
   }
-#else
-  a_pack_ptr_ = reinterpret_cast<float *>(context_->allocator->Malloc(matrix_a_pack_size_ * sizeof(float)));
-#endif
   if (a_pack_ptr_ == nullptr) {
     MS_LOG(ERROR) << "malloc a_pack_ptr_ failed";
     return RET_ERROR;
@@ -92,15 +88,11 @@ int MatmulFp32BaseCPUKernel::InitBufferB() {
   if (b_pack_ptr_ != nullptr) {
     return RET_OK;
   }
-#ifdef SUPPORT_TRAIN
   if (op_parameter_->is_train_session_) {
     b_pack_ptr_ = reinterpret_cast<float *>(workspace()) + matrix_a_pack_size_;
   } else {
     b_pack_ptr_ = reinterpret_cast<float *>(context_->allocator->Malloc(matrix_b_pack_size_ * sizeof(float)));
   }
-#else
-  b_pack_ptr_ = reinterpret_cast<float *>(context_->allocator->Malloc(matrix_b_pack_size_ * sizeof(float)));
-#endif
   if (b_pack_ptr_ == nullptr) {
     MS_LOG(ERROR) << "malloc b_pack_ptr_ failed";
     return RET_ERROR;
@@ -328,9 +320,9 @@ int MatmulFp32BaseCPUKernel::ReSize() {
                   << "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size" << matrix_b_pack_size_;
     return RET_ERROR;
   }
-#ifdef SUPPORT_TRAIN
-  set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * sizeof(float));
-#endif
+  if (op_parameter_->is_train_session_) {
+    set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * sizeof(float));
+  }
 
   if (params_->b_const_ == true && src_b_ != nullptr) {
     if (RET_OK != InitBufferB()) {
diff --git a/mindspore/lite/test/st/run_net_train.sh b/mindspore/lite/test/st/run_net_train.sh
index f54e310b345..7eed84fcc97 100755
--- a/mindspore/lite/test/st/run_net_train.sh
+++ b/mindspore/lite/test/st/run_net_train.sh
@@ -74,7 +74,7 @@ function Run_Converter() {
 # Run on x86 platform:
 function Run_x86() {
     cd ${x86_path}/mindspore-lite-${version}-train-linux-x64 || return 1
-    export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./train/lib:./train/third_party/libjpeg-turbo/lib
+    export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./inference/lib:./inference/third_party/libjpeg-turbo/lib
     # Run mindspore converted train models:
     fail=0
     while read line; do
@@ -141,18 +141,19 @@ function Run_arm() {
 
     # If build with minddata, copy the minddata related libs
     cd ${benchmark_train_test_path} || exit 1
-    if [ -f ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/lib/libminddata-lite.so ]; then
-        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/third_party/libjpeg-turbo/lib/libjpeg.so* ${benchmark_train_test_path}/ || exit 1
-        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/third_party/libjpeg-turbo/lib/libturbojpeg.so* ${benchmark_train_test_path}/ || exit 1
-        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/lib/libminddata-lite.so ${benchmark_train_test_path}/libminddata-lite.so || exit 1
+    if [ -f ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/lib/libminddata-lite.so ]; then
+        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/third_party/libjpeg-turbo/lib/libjpeg.so* ${benchmark_train_test_path}/ || exit 1
+        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/third_party/libjpeg-turbo/lib/libturbojpeg.so* ${benchmark_train_test_path}/ || exit 1
+        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/lib/libminddata-lite.so ${benchmark_train_test_path}/libminddata-lite.so || exit 1
     fi
     if [ "$1" == arm64 ] || [ "$1" == arm32 ]; then
-        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/third_party/hiai_ddk/lib/libhiai.so ${benchmark_train_test_path}/libhiai.so || exit 1
-        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/third_party/hiai_ddk/lib/libhiai_ir.so ${benchmark_train_test_path}/libhiai_ir.so || exit 1
-        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/third_party/hiai_ddk/lib/libhiai_ir_build.so ${benchmark_train_test_path}/libhiai_ir_build.so || exit 1
+        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/third_party/hiai_ddk/lib/libhiai.so ${benchmark_train_test_path}/libhiai.so || exit 1
+        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/third_party/hiai_ddk/lib/libhiai_ir.so ${benchmark_train_test_path}/libhiai_ir.so || exit 1
+        cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/third_party/hiai_ddk/lib/libhiai_ir_build.so ${benchmark_train_test_path}/libhiai_ir_build.so || exit 1
     fi
 
-    cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/train/lib/libmindspore-lite-train.so ${benchmark_train_test_path}/libmindspore-lite-train.so || exit 1
+    cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/lib/libmindspore-lite.so ${benchmark_train_test_path}/libmindspore-lite.so || exit 1
+    cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/inference/lib/libmindspore-lite-train.so ${benchmark_train_test_path}/libmindspore-lite-train.so || exit 1
     cp -a ${arm_path}/mindspore-lite-${version_arm}-train-android-${process_unit}/tools/benchmark_train/benchmark_train ${benchmark_train_test_path}/benchmark_train || exit 1
 
     # adb push all needed files to the phone
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc
index 25816a8addb..207b9f15b42 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc
@@ -69,6 +69,7 @@ int FcTestInit1(std::vector<lite::Tensor *> *inputs_, std::vector<lite::Tensor *
   matmal_param->has_bias_ = true;
   matmal_param->act_type_ = ActType_No;
   matmal_param->op_parameter_.type_ = 67;
+  matmal_param->op_parameter_.is_train_session_ = false;
   KernelInferShape(*inputs_, *outputs_, reinterpret_cast<OpParameter *>(matmal_param));
   return out_t->ElementsNum();
 }
@@ -84,15 +85,9 @@ TEST_F(TestFcFp32, FcTest1) {
   ASSERT_EQ(lite::RET_OK, ctx->Init());
   auto *fc = new kernel::FullconnectionCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
   fc->Init();
-#ifdef SUPPORT_TRAIN
-  mindspore::kernel::InnerKernel::AllocWorkspace(fc->workspace_size());
-#endif
   fc->Run();
   ASSERT_EQ(0, CompareOutputData(reinterpret_cast<float *>(outputs_[0]->MutableData()), correct, total_size, 0.0001));
   delete ctx;
-#ifdef SUPPORT_TRAIN
-  mindspore::kernel::InnerKernel::FreeWorkspace();
-#endif
 }
 
 int FcTestInit2(std::vector<lite::Tensor *> *inputs_, std::vector<lite::Tensor *> *outputs_,
@@ -149,15 +144,9 @@ TEST_F(TestFcFp32, FcTest2) {
   ASSERT_EQ(lite::RET_OK, ctx->Init());
   auto *fc = new kernel::FullconnectionCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
   fc->Init();
-#ifdef SUPPORT_TRAIN
-  mindspore::kernel::InnerKernel::AllocWorkspace(fc->workspace_size());
-#endif
   fc->Run();
   ASSERT_EQ(0, CompareOutputData(reinterpret_cast<float *>(outputs_[0]->MutableData()), correct, total_size, 0.0001));
   delete ctx;
-#ifdef SUPPORT_TRAIN
-  mindspore::kernel::InnerKernel::FreeWorkspace();
-#endif
 }
 
 void FcTestInit3(std::vector<lite::Tensor *> *inputs_, std::vector<lite::Tensor *> *outputs_,
@@ -204,18 +193,12 @@ TEST_F(TestFcFp32, FcTest3) {
   ASSERT_EQ(lite::RET_OK, ctx->Init());
   auto *fc = new kernel::FullconnectionCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
   fc->Init();
-#ifdef SUPPORT_TRAIN
-  mindspore::kernel::InnerKernel::AllocWorkspace(fc->workspace_size());
-#endif
   struct timeval start, end;
   gettimeofday(&start, nullptr);
   for (int i = 0; i < 100000; ++i) fc->Run();
   gettimeofday(&end, nullptr);
   // printf("## elapsed: %llu\n", 1000000 * (end.tv_sec - start.tv_sec) + end.tv_usec - end.tv_usec);
   delete ctx;
-#ifdef SUPPORT_TRAIN
-  mindspore::kernel::InnerKernel::FreeWorkspace();
-#endif
 }
 
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/benchmark_train/CMakeLists.txt b/mindspore/lite/tools/benchmark_train/CMakeLists.txt
index cbef9f3f655..aad0a05508f 100644
--- a/mindspore/lite/tools/benchmark_train/CMakeLists.txt
+++ b/mindspore/lite/tools/benchmark_train/CMakeLists.txt
@@ -10,21 +10,22 @@ add_executable(benchmark_train
         ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
         ${COMMON_SRC})
 if(WIN32)
-    add_dependencies(benchmark_train fbs_src mindspore-lite_static)
+    add_dependencies(benchmark_train fbs_src mindspore-lite_static mindspore-lite-train_static)
 else()
-    add_dependencies(benchmark_train fbs_src)
+    add_dependencies(benchmark_train fbs_src mindspore-lite_static mindspore-lite-train_static)
 endif()
 
 if(PLATFORM_ARM32 OR PLATFORM_ARM64)
     if(SUPPORT_NPU AND ANDROID_STL STREQUAL "c++_static")
-        target_link_libraries(benchmark_train mindspore-lite minddata-lite c++_shared)
+        target_link_libraries(benchmark_train mindspore-lite minddata-lite mindspore-lite-train c++_shared)
     else()
-        target_link_libraries(benchmark_train mindspore-lite minddata-lite)
+        target_link_libraries(benchmark_train mindspore-lite minddata-lite mindspore-lite-train)
     endif()
 else()
     if(WIN32)
-        target_link_libraries(benchmark_train mindspore-lite_static pthread cpu_kernel_mid nnacl_mid minddata-lite)
+        target_link_libraries(benchmark_train mindspore-lite_static mindspore-lite-train_static pthread cpu_kernel_mid
+                              nnacl_mid minddata-lite)
     else()
-        target_link_libraries(benchmark_train mindspore-lite pthread minddata-lite)
+        target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train pthread minddata-lite)
     endif()
 endif()