diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index c7fc328ffb3..cedcb57cc9e 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -49,6 +49,7 @@ set(LITE_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/common/prim_util.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/common/tensor_util.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/common/loader_util.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.c
@@ -124,6 +125,7 @@ if(SUPPORT_TRAIN)
             ${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_monitor.cc
             ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
             ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
+            ${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
             )
     if(ENABLE_V0)
       set(LITE_SRC
@@ -192,7 +194,10 @@ if(BUILD_MINDDATA STREQUAL "lite")
     target_link_libraries(mindspore-lite_static minddata_eager_mid)
 endif()
 if(SUPPORT_TRAIN)
+    add_dependencies(mindspore-lite fbs_inner_src)
+    add_dependencies(mindspore-lite_static fbs_inner_src)
     target_link_libraries(mindspore-lite minddata-lite)
+    target_link_libraries(mindspore-lite_static minddata-lite)
 endif()
 
 
diff --git a/mindspore/lite/src/common/quant_utils.cc b/mindspore/lite/src/common/quant_utils.cc
new file mode 100644
index 00000000000..30a1a11e8ec
--- /dev/null
+++ b/mindspore/lite/src/common/quant_utils.cc
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "schema/inner/model_generated.h"
+#include "src/common/quant_utils.h"
+#include "src/lite_kernel.h"
+
+namespace mindspore {
+namespace lite {
+
+void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
+                         bool channel_at_first, float *desired_max, float *desired_min) {
+  float min = FLT_MAX;
+  float max = -FLT_MAX;
+  // find min and max
+  for (int j = 0; j < one_filter_size; j++) {
+    auto index = j + i * one_filter_size;
+    if (!channel_at_first) {
+      index = j * channels + i;
+    }
+    if (index >= elem_count) {
+      MS_LOG(ERROR) << "over flow!";
+    }
+    min = std::min(min, raw_datas[index]);
+    max = std::max(max, raw_datas[index]);
+  }
+  *desired_max = max;
+  *desired_min = min;
+}
+
+STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
+                             int quant_min, int num_bits) {
+  MS_ASSERT(quantParam != nullptr);
+  if (mMin > 0.0f) {
+    MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
+    mMin = 0.0f;
+  }
+  if (mMax < 0.0f) {
+    MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
+    mMax = 0.0f;
+  }
+  if (mMin > mMax) {
+    MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
+    return RET_PARAM_INVALID;
+  }
+  if (mMin == mMax) {
+    if (mMin != 0.0f) {
+      MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
+      return RET_ERROR;
+    }
+    quantParam->inited = true;
+    quantParam->min = mMin;
+    quantParam->max = mMax;
+    quantParam->scale = 0.0f;
+    quantParam->zeroPoint = 0;
+    quantParam->narrowRange = narrowRange;
+    quantParam->numBits = num_bits;
+    return RET_OK;
+  }
+
+  auto quantMinFloat = static_cast<double>(quant_min);
+  auto quantMaxFloat = static_cast<double>(quant_max);
+  if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {
+    MS_LOG(ERROR) << "divisor cannot be 0";
+    return RET_ERROR;
+  }
+  double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
+  if (fabs(scale) <= 0.0f) {
+    MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
+    return RET_ERROR;
+  }
+  const double zeroPointFromMin = quantMinFloat - mMin / scale;
+  int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));
+
+  // The zero point should always be in the range of quantized value,
+  // [qmin, qmax].
+  MS_ASSERT(zeroPoint >= quant_min);
+  MS_ASSERT(zeroPoint <= quant_max);
+  quantParam->inited = true;
+  quantParam->min = mMin;
+  quantParam->max = mMax;
+  quantParam->scale = scale;
+  quantParam->zeroPoint = zeroPoint;
+  quantParam->narrowRange = narrowRange;
+  quantParam->numBits = num_bits;
+
+  return RET_OK;
+}
+
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/common/quant_utils.h b/mindspore/lite/src/common/quant_utils.h
new file mode 100644
index 00000000000..5aa49a31b0b
--- /dev/null
+++ b/mindspore/lite/src/common/quant_utils.h
@@ -0,0 +1,234 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
+#define MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
+
+#include <float.h>
+#include <cmath>
+#include <climits>
+#include <limits>
+#include <algorithm>
+#include <vector>
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "ir/dtype/type_id.h"
+
+namespace mindspore {
+
+namespace schema {
+struct QuantParamT;
+}
+
+namespace lite {
+const int RET_QUANT_CONTINUE = 2;
+static constexpr double SCALE_THREASHOLD = 1e-38;
+
+static constexpr int kPerTensor = 1;
+
+inline int QuantMax(int bits, TypeId type) {
+  if (type == kNumberTypeInt8) {
+    return (1 << (bits - 1)) - 1;
+  } else if (type == kNumberTypeUInt8) {
+    return (1 << bits) - 1;
+  }
+  return 0;
+}
+
+inline int QuantMin(int bits, TypeId type) {
+  if (type == kNumberTypeInt8) {
+    return -(1 << (bits - 1));
+  }
+  return 0;
+}
+
+void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
+                         bool channel_at_first, float *desired_max, float *desired_min);
+
+STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
+                             int quant_min, int num_bits);
+
+template <typename T>
+T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
+  MS_ASSERT(quantParam != nullptr);
+  MS_ASSERT(quantParam->inited);
+  const auto scale = quantParam->scale;
+  const auto zeroPoint = quantParam->zeroPoint;
+  const auto numBit = quantParam->numBits;
+  const auto narrowRange = quantParam->narrowRange;
+  double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1);
+  const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<T>::min()) * scale;
+  double minLimit;
+  if (narrowRange) {
+    minLimit = static_cast<float>(std::numeric_limits<T>::min() + 1 - zeroPoint) * scale;
+  } else {
+    minLimit = static_cast<float>(std::numeric_limits<T>::min() - zeroPoint) * scale;
+  }
+
+  return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
+    double tmp;
+    if (originData > maxLimit) {
+      tmp = maxLimit;
+    } else if (originData < minLimit) {
+      tmp = minLimit;
+    } else {
+      tmp = originData;
+    }
+    auto quantData = static_cast<T>(std::round(zeroPoint + tmp / scale));
+    return quantData;
+  }();
+}
+
+template <typename T>
+T QuantizeData(float originData, const schema::QuantParamT *quantParam, int quant_max, int quant_min) {
+  MS_ASSERT(quantParam != nullptr);
+  MS_ASSERT(quantParam->inited);
+  const auto scale = quantParam->scale;
+  const int zeroPoint = quantParam->zeroPoint;
+  const int maxLimit = quant_max;
+  const int minLimit = quant_min;
+
+  if (scale <= SCALE_THREASHOLD) {
+    return 0;
+  }
+
+  return [maxLimit, minLimit, zeroPoint, scale, originData] {
+    auto quant_data = std::round(originData / scale + zeroPoint);
+    if (quant_data > maxLimit) {
+      quant_data = maxLimit;
+    } else if (quant_data < minLimit) {
+      quant_data = minLimit;
+    }
+    return static_cast<T>(quant_data);
+  }();
+}
+
+template <typename T>
+STATUS DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
+                       const int &quant_max, const int &quant_min, const size_t &bit_num, const bool &k_means,
+                       std::vector<T> *quant_datas) {
+  float min = FLT_MAX;
+  float max = -FLT_MIN;
+  for (uint32_t i = 0; i < elem_count; i++) {
+    min = std::min(min, raw_datas[i]);
+    max = std::max(max, raw_datas[i]);
+  }
+
+  schema::QuantParamT quant_param;
+  if (!k_means) {
+    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
+      return status;
+    }
+  }
+  quant_params->emplace_back(quant_param);
+  // update data and datatype
+  for (uint32_t i = 0; i < elem_count; i++) {
+    float raw_data = raw_datas[i];
+    if (!k_means) {
+      auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
+      (*quant_datas)[i] = quant_data;
+    }
+  }
+  return RET_OK;
+}
+
+template <typename T>
+STATUS DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::QuantType &quant_type,
+                         std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
+                         const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas, int channels,
+                         bool channel_at_first = true) {
+  static const int quant_param_size = 32 * 8;
+  std::vector<float> dequant_datas(quant_datas->size());
+  if (channels <= 0) {
+    MS_LOG(ERROR) << "channels must be greater than 0";
+    return RET_ERROR;
+  }
+  size_t one_filter_size = elem_count / channels;
+  bool do_quant = quant_param_size / (sizeof(float) * 8 - bit_num) < one_filter_size;
+  if (!do_quant && quant_type == schema::QuantType_WeightQuant) {
+    MS_LOG(INFO) << "too few elements in a filter, no need to quantize. " << one_filter_size;
+    return RET_QUANT_CONTINUE;
+  }
+  for (int i = 0; i < channels; i++) {
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+    GetMaxMinPerchannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min);
+    schema::QuantParamT quant_param;
+    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
+      return status;
+    }
+    // do quantization
+    double average_dequant = 0;
+    double average_raw = 0;
+    for (uint32_t j = 0; j < one_filter_size; j++) {
+      auto index = j + i * one_filter_size;
+      if (!channel_at_first) {
+        index = j * channels + i;
+      }
+      MS_ASSERT(index < elem_count);
+      float raw_data = raw_datas[index];
+      auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
+      (*quant_datas)[index] = quant_data;
+
+      if (quant_type == schema::QuantType_WeightQuant) {
+        float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint);
+        dequant_datas[index] = dequant_data;
+        average_dequant += dequant_data;
+        average_raw += raw_data;
+      }
+    }
+    if (quant_type == schema::QuantType_WeightQuant && !k_means) {
+      // mean
+      average_dequant = average_dequant / one_filter_size;
+      average_raw = average_raw / one_filter_size;
+      // std
+      double variance_dequant = 0;
+      double variance_raw = 0;
+      for (uint32_t j = 0; j < one_filter_size; j++) {
+        auto index = j + i * one_filter_size;
+        if (!channel_at_first) {
+          index = j * channels + i;
+        }
+        MS_ASSERT(index < elem_count);
+        variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2);
+        variance_raw += std::pow(raw_datas[index] - average_raw, 2);
+      }
+      variance_dequant = std::sqrt(variance_dequant / one_filter_size);
+      variance_raw = std::sqrt(variance_raw / one_filter_size);
+      quant_param.varCorr = 1;
+      if (variance_raw != 0 && variance_dequant != 0) {
+        auto temp_var_corr = variance_raw / variance_dequant;
+        if (temp_var_corr > 0 && temp_var_corr < 10) {
+          quant_param.varCorr = temp_var_corr;
+        } else {
+          MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
+        }
+      }
+      quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr;
+    }
+    quant_params->emplace_back(quant_param);
+  }
+  return RET_OK;
+}
+
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc
index 5f7761d705c..0d52844804c 100644
--- a/mindspore/lite/src/train/train_export.cc
+++ b/mindspore/lite/src/train/train_export.cc
@@ -23,46 +23,93 @@
 #include <set>
 #include "schema/inner/model_generated.h"
 #include "src/train/train_utils.h"
+#include "src/common/quant_utils.h"
+#include "tools/common/storage.h"
 
 namespace mindspore {
 namespace lite {
 
-std::vector<uint8_t> TrainExport::CreateData(const mindspore::lite::Tensor *tensor) {
+std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
   uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data_c());
   auto size = tensor->Size();
   std::vector<uint8_t> data(tensor_data, tensor_data + size);
   return data;
 }
 
+bool TrainExport::NeedQuantization(const lite::Tensor *tensor) {
+  return (tensor->quant_params().size() > 0 && tensor->quant_params().at(0).inited);
+}
+
+schema::QuantType TrainExport::GetNodeQuantType(const kernel::LiteKernel *kernel) {
+  if (std::any_of(kernel->in_tensors().cbegin(), kernel->in_tensors().cend(), [](const lite::Tensor *t) {
+        return (t->IsConst() && (t->quant_params().size() > 0) && (t->quant_params().at(0).inited));
+      })) {
+    return schema::QuantType_QUANT_WEIGHT;
+  }
+  return schema::QuantType_QUANT_NONE;
+}
+
+int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tensor *src_tensor) {
+  int channels = src_tensor->quant_params().size();
+  if (channels < 1) {
+    MS_LOG(ERROR) << "Quant Params is empty";
+    return RET_ERROR;
+  }
+  int bit_num = src_tensor->quant_params().at(0).bitNum;
+  int quant_max = QuantMax(bit_num, kNumberTypeInt8);
+  int quant_min = QuantMin(bit_num, kNumberTypeInt8);
+  std::vector<int8_t> data(src_tensor->ElementsNum());
+  std::vector<schema::QuantParamT> quant_params;
+
+  STATUS ret = RET_OK;
+  if (channels == kPerTensor) {
+    ret = DoPerLayerQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data_c()), src_tensor->ElementsNum(),
+                                  &(quant_params), quant_max, quant_min, bit_num, false, &data);
+  } else {
+    bool channel_at_first = (src_tensor->shape().at(0) == channels);
+    ret = DoPerChannelQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data_c()), src_tensor->ElementsNum(),
+                                    schema::QuantType_WeightQuant, &(quant_params), quant_max, quant_min, bit_num,
+                                    false, &data, channels, channel_at_first);
+  }
+  if (ret == RET_QUANT_CONTINUE) {
+    MS_LOG(DEBUG) << "No Need to quant per channel";
+    return RET_OK;
+  }
+  if (ret == RET_ERROR) {
+    MS_LOG(ERROR) << "QuantTensorData error,  channels = " << channels;
+    return ret;
+  }
+  if (quant_params.empty()) {
+    MS_LOG(ERROR) << "quant_params empty";
+    return RET_ERROR;
+  }
+  dest_tensor->data = std::vector<uint8_t>(data.data(), data.data() + data.size());
+  dest_tensor->dataType = kNumberTypeInt8;
+  dest_tensor->quantParams.clear();
+  for (auto quant_param : quant_params) {
+    dest_tensor->quantParams.emplace_back(std::make_unique<schema::QuantParamT>(quant_param));
+  }
+
+  return RET_OK;
+}
+
 std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite::Tensor *tensor,
                                                            schema::Tensor *scTensor) {
   auto tensorT = std::make_unique<schema::TensorT>();
   tensorT->nodeType = scTensor->nodeType();
-  tensorT->dataType = tensor->data_type();
   tensorT->dims = tensor->shape();
   tensorT->format = tensor->format();
   tensorT->name = tensor->tensor_name();
   tensorT->refCount = 0;
   tensorT->offset = 0;
+  tensorT->dataType = tensor->data_type();
   tensorT->enableHuffmanCode = false;
   if ((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) {
-    tensorT->data = CreateData(tensor);
-  }
-  for (auto quant_param : tensor->quant_params()) {
-    auto quantParamT = std::make_unique<schema::QuantParamT>();
-    quantParamT->scale = quant_param.scale;
-    quantParamT->zeroPoint = quant_param.zeroPoint;
-    quantParamT->min = 0;
-    quantParamT->max = 0;
-    quantParamT->narrowRange = true;
-    quantParamT->numBits = quant_param.bitNum;
-    quantParamT->inited = quant_param.inited;
-    quantParamT->varCorr = quant_param.var_corr;
-    quantParamT->meanCorr = quant_param.mean_corr;
-    quantParamT->dstDtype = quant_param.dstDtype;
-    quantParamT->roundType = quant_param.roundType;
-    quantParamT->multiplier = quant_param.multiplier;
-    tensorT->quantParams.emplace_back(std::move(quantParamT));
+    if (NeedQuantization(tensor)) {
+      QuantTensorData(tensorT.get(), tensor);
+    } else {
+      tensorT->data = CreateData(tensor);
+    }
   }
   tensorT->quantClusters = tensor->quant_clusters();
   return tensorT;
@@ -85,7 +132,7 @@ std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel
   cnodeT->inputIndex = inputIndex;
   cnodeT->outputIndex = outputIndex;
   cnodeT->name = kernel->name();
-  cnodeT->quantType = schema::QuantType_QUANT_NONE;
+  cnodeT->quantType = GetNodeQuantType(kernel);
   // find kernel in model
   auto *node = FindNode(kernel);
   if (node == nullptr) {
@@ -132,7 +179,6 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern
         MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
         return RET_ERROR;
       }
-      out_set.insert(id);
       auto it = remap.find(id);
       if (it == remap.end()) {
         remap[id] = tensor_idx;
@@ -153,7 +199,7 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern
     schema::Tensor *scTensor = model_->all_tensors_.at(id);
     auto tensorT = CreateTensor(tensor, scTensor);
     // find a tensor which is not an output
-    if (out_set.find(id) == out_set.end()) {
+    if (out_set.find(remap[id]) == out_set.end()) {
       if ((tensorT->nodeType == NodeType_ValueNode) && (tensorT->data.size() == 0)) {
         meta_graph->inputIndex.push_back(remap[id]);
       }
@@ -165,7 +211,7 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern
     meta_graph->allTensors.emplace_back(std::move(tensorT));
   }
   auto graph = meta_graph.release();
-  int err = SaveToFile(graph, file_name_);
+  int err = Storage::Save(*graph, file_name_);
   if (err != RET_OK) {
     MS_LOG(ERROR) << "failed to save flatbuffer file " << file_name_;
   }
@@ -173,30 +219,5 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern
   return err;
 }
 
-int TrainExport::SaveToFile(const schema::MetaGraphT *graph, const std::string &outputPath) {
-  flatbuffers::FlatBufferBuilder builder(1024);
-  auto offset = schema::MetaGraph::Pack(builder, graph);
-  builder.Finish(offset);
-  schema::FinishMetaGraphBuffer(builder, offset);
-  int size = builder.GetSize();
-  auto content = builder.GetBufferPointer();
-  if (content == nullptr) {
-    MS_LOG(ERROR) << "GetBufferPointer nullptr";
-    return RET_ERROR;
-  }
-  if (access((outputPath + ".ms").c_str(), F_OK) == 0) {
-    chmod((outputPath + ".ms").c_str(), S_IWUSR);
-  }
-  std::ofstream output(outputPath + ".ms", std::ofstream::binary);
-  if (!output.is_open()) {
-    MS_LOG(ERROR) << "Can not open output file: " << outputPath << ".ms";
-    return RET_ERROR;
-  }
-  output.write((const char *)content, size);
-  output.close();
-  chmod((outputPath + ".ms").c_str(), S_IRUSR);
-  return RET_OK;
-}
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h
index 41cac802378..2639dc6ccec 100644
--- a/mindspore/lite/src/train/train_export.h
+++ b/mindspore/lite/src/train/train_export.h
@@ -50,7 +50,10 @@ class TrainExport {
   std::unique_ptr<schema::TensorT> CreateTensor(const mindspore::lite::Tensor *tensor, schema::Tensor *scTensor);
   std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::LiteKernel *kernel,
                                               std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex);
-  int SaveToFile(const schema::MetaGraphT *graph, const std::string &outputPath);
+
+  bool NeedQuantization(const mindspore::lite::Tensor *tensor);
+  virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor);
+  mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::LiteKernel *kernel);
 };
 };  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/weight_decoder.cc b/mindspore/lite/src/weight_decoder.cc
index 1e396c50b74..b90e8e78d10 100644
--- a/mindspore/lite/src/weight_decoder.cc
+++ b/mindspore/lite/src/weight_decoder.cc
@@ -32,7 +32,7 @@ std::vector<bool> StringToBitVector(const std::string &str) {
 }
 
 STATUS IndexingDecompress(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
-  MS_LOG(ERROR) << "un-index weight";
+  MS_LOG(DEBUG) << "un-index weight";
   auto bit_num = src_tensor.quantParams()->Get(0)->numBits();
 
   std::string str(reinterpret_cast<const char *>(src_tensor.data()->data()), src_tensor.data()->size());
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index a4f59dcb0ce..d7052173af7 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -159,6 +159,7 @@ set(TEST_LITE_SRC
         ${LITE_DIR}/src/common/utils.cc
         ${LITE_DIR}/src/common/loader_util.cc
         ${LITE_DIR}/src/common/string_util.cc
+        ${LITE_DIR}/src/common/quant_utils.cc
         ${LITE_DIR}/tools/common/flag_parser.cc
         ${LITE_DIR}/tools/benchmark/benchmark.cc
         ${LITE_DIR}/test/st/benchmark_test.cc
@@ -306,6 +307,7 @@ if(SUPPORT_TRAIN)
             ${LITE_DIR}/src/train/train_utils.cc
             ${LITE_DIR}/src/train/transfer_session.cc
             ${LITE_DIR}/src/lite_session.cc
+            ${LITE_DIR}/tools/common/storage.cc
             )
 else()
     set(TEST_LITE_SRC
@@ -397,6 +399,10 @@ endif()
 add_executable(lite-test ${TEST_SRC})
 add_dependencies(lite-test fbs_src)
 
+if(SUPPORT_TRAIN)
+    add_dependencies(lite-test fbs_inner_src)
+endif()
+
 target_link_libraries(lite-test dl mindspore::gtest)
 
 if(PLATFORM_ARM AND ENABLE_FP16)
diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg
index cd9de523c06..7627535382c 100644
--- a/mindspore/lite/test/models_ms_train.cfg
+++ b/mindspore/lite/test/models_ms_train.cfg
@@ -11,7 +11,7 @@ googlenet
 densenet
 shufflenetv2
 mini_alexnet weight_quant 2
-nin weight_quant 7
+nin weight_quant 9
 lenet weight_quant 5
 mobilenetv1 weight_quant 2
 mobilenetv2 weight_quant 2
diff --git a/mindspore/lite/test/run_net_train.sh b/mindspore/lite/test/run_net_train.sh
index c259b5ea386..6dc0282caa9 100755
--- a/mindspore/lite/test/run_net_train.sh
+++ b/mindspore/lite/test/run_net_train.sh
@@ -82,22 +82,27 @@ function Run_x86() {
         model_prefix=${line_array[0]}
         model_name=${line_array[0]}'_train'
         accuracy_limit=0.5
+        export_file=""
+        inference_file=""
         if [[ $model_name == \#* ]]; then
           continue
         fi
         if [[ "${line_array[1]}" == "weight_quant" ]]; then
             model_name=${line_array[0]}'_train_quant'
             accuracy_limit=${line_array[2]}
+        else
+            export_file="${ms_models_path}/${model_name}_tod"
+            rm -f ${export_file}"*"
         fi
-        if [[ "${save_lite}" == "1" ]]; then
-          inference_file="${ms_models_path}/${model_name}_infer"
-        fi
+        inference_file="${ms_models_path}/${model_name}_infer"
+        rm -f ${inference_file}"*"
         echo ${model_name} >> "${run_x86_log_file}"
         ${run_valgrind}./tools/benchmark_train/benchmark_train \
         --modelFile=${ms_models_path}/${model_name}.ms \
-        --inDataFile=${train_io_path}/${model_prefix}_input1.bin,${train_io_path}/${model_prefix}_input2.bin \
+        --inDataFile=${train_io_path}/${model_prefix}_input \
         --expectedDataFile=${train_io_path}/${model_prefix}_output --epochs=${epoch_num} --numThreads=${threads} \
-        --accuracyThreshold=${accuracy_limit} --inferenceFile=${inference_file} >> "${run_x86_log_file}"
+        --accuracyThreshold=${accuracy_limit} --inferenceFile=${inference_file} \
+        --exportFile=${export_file} >> "${run_x86_log_file}"
         if [ $? = 0 ]; then
             run_result='x86: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
         else
@@ -168,21 +173,22 @@ function Run_arm() {
         model_prefix=${line_array[0]}
         model_name=${line_array[0]}'_train'
         accuracy_limit=0.5
+        export_file=""
         if [[ $model_name == \#* ]]; then
             continue
         fi
         if [[ "${line_array[1]}" == "weight_quant" ]]; then
             model_name=${line_array[0]}'_train_quant'
             accuracy_limit=${line_array[2]}
+        else
+            export_file="${tmp_dir}/${model_name}_tod"
         fi
+        inference_file="${tmp_dir}/${model_name}_infer"
 
         if [[ "${line_array[1]}" == "noarm32" ]] && [[ "$1" == arm32 ]]; then
             run_result=$1': '${model_name}' irrelevant'; echo ${run_result} >> ${run_benchmark_train_result_file}
             continue
         fi
-        if [[ "${save_lite}" == "1" ]]; then
-          inference_file="${ms_models_path}/${model_name}_infer"
-        fi
         # run benchmark_train test without clib data
         echo ${model_name} >> "${run_arm_log_file}"
         adb -s ${device_id} push ${train_io_path}/${model_prefix}_input*.bin ${train_io_path}/${model_prefix}_output*.bin  /data/local/tmp/benchmark_train_test >> ${adb_push_log_file}
@@ -193,15 +199,20 @@ function Run_arm() {
         elif [ "$1" == arm32 ]; then
             echo 'cp  /data/local/tmp/arm32/libc++_shared.so ./' >> ${adb_cmd_run_file}
         fi 
-        echo "rm -f ${tmp_dir}/${model_name}_exported.ms" >> ${run_arm_log_file}
-        echo "rm -f ${tmp_dir}/${model_name}_exported.ms" >> ${adb_cmd_run_file}
+        adb -s ${device_id} shell < ${adb_cmd_run_file} >> ${run_arm_log_file}
+        echo "rm -f ${export_file} ${inference_file}.ms" >> ${run_arm_log_file}
+        echo "rm -f ${export_file} ${inference_file}.ms" >> ${adb_cmd_run_file}
+        adb -s ${device_id} shell < ${adb_cmd_run_file} >> ${run_arm_log_file}
         adb_cmd=$(cat <<-ENDM
         export LD_LIBRARY_PATH=./:/data/local/tmp/:/data/local/tmp/benchmark_train_test;./benchmark_train \
         --epochs=${epoch_num} \
         --modelFile=${model_name}.ms \
-        --inDataFile=${tmp_dir}/${model_prefix}_input1.bin,${tmp_dir}/${model_prefix}_input2.bin \
+        --inDataFile=${tmp_dir}/${model_prefix}_input \
         --expectedDataFile=${tmp_dir}/${model_prefix}_output \
-        --numThreads=${threads} --accuracyThreshold=${accuracy_limit} --inferenceFile=${inference_file}
+        --numThreads=${threads} \
+        --accuracyThreshold=${accuracy_limit} \
+        --inferenceFile=${inference_file} \
+        --exportFile=${export_file}
 ENDM
         )
         echo "${adb_cmd}" >> ${run_arm_log_file}
@@ -252,7 +263,7 @@ models_mindspore_train_config=${basepath}/models_ms_train.cfg
 epoch_num=1
 threads=2
 train_io_path=""
-while getopts "r:M:c:m:d:i:e:vt:q:DF" opt; do
+while getopts "r:M:c:m:d:i:e:vt:q:D" opt; do
     case ${opt} in
         r)
            release_path=${OPTARG}
@@ -295,8 +306,6 @@ while getopts "r:M:c:m:d:i:e:vt:q:DF" opt; do
             epoch_num=${OPTARG}
             echo "train epoch num is ${epoch_num}"
             ;;
-        F)  save_lite=1
-            ;;                          
         ?)
             echo "unknown para"
             exit 1;;
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 5ffb0a37292..b3858d85ddd 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -107,7 +107,7 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
     int bit_num = tensor_input->quantParams.at(0)->numBits;
     // Pack Repetition
     auto repetition_packed = false;
-    MS_LOG(ERROR) << dst_node->name;
+    MS_LOG(DEBUG) << dst_node->name;
     if (dst_node->quantType == schema::QuantType_QUANT_WEIGHT) {
       if (bit_num <= 8) {
         repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc
index 2271defc52b..5b494d5ed2d 100644
--- a/mindspore/lite/tools/benchmark_train/net_train.cc
+++ b/mindspore/lite/tools/benchmark_train/net_train.cc
@@ -32,8 +32,6 @@
 
 namespace mindspore {
 namespace lite {
-static const char *DELIM_COLON = ":";
-static const char *DELIM_COMMA = ",";
 static const char *DELIM_SLASH = "/";
 
 namespace {
@@ -81,8 +79,8 @@ int NetTrain::GenerateRandomData(size_t size, void *data) {
   return RET_OK;
 }
 
-int NetTrain::GenerateInputData() {
-  for (auto tensor : ms_inputs_) {
+int NetTrain::GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
+  for (auto tensor : *ms_inputs) {
     MS_ASSERT(tensor != nullptr);
     auto input_data = tensor->MutableData();
     if (input_data == nullptr) {
@@ -100,16 +98,16 @@ int NetTrain::GenerateInputData() {
   return RET_OK;
 }
 
-int NetTrain::LoadInput() {
+int NetTrain::LoadInput(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
   if (flags_->in_data_file_.empty()) {
-    auto status = GenerateInputData();
+    auto status = GenerateInputData(ms_inputs);
     if (status != RET_OK) {
       std::cerr << "Generate input data error " << status << std::endl;
       MS_LOG(ERROR) << "Generate input data error " << status;
       return status;
     }
   } else {
-    auto status = ReadInputFile();
+    auto status = ReadInputFile(ms_inputs);
     if (status != RET_OK) {
       std::cerr << "ReadInputFile error, " << status << std::endl;
       MS_LOG(ERROR) << "ReadInputFile error, " << status;
@@ -119,8 +117,8 @@ int NetTrain::LoadInput() {
   return RET_OK;
 }
 
-int NetTrain::ReadInputFile() {
-  if (ms_inputs_.empty()) {
+int NetTrain::ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
+  if (ms_inputs->empty()) {
     return RET_OK;
   }
 
@@ -128,16 +126,12 @@ int NetTrain::ReadInputFile() {
     MS_LOG(ERROR) << "Not supported image input";
     return RET_ERROR;
   } else {
-    if (ms_inputs_.size() > flags_->input_data_list_.size()) {
-      MS_LOG(ERROR) << "missing input files expecting " << ms_inputs_.size() << ",got "
-                    << flags_->input_data_list_.size();
-      return RET_ERROR;
-    }
-    for (size_t i = 0; i < ms_inputs_.size(); i++) {
-      auto cur_tensor = ms_inputs_.at(i);
+    for (size_t i = 0; i < ms_inputs->size(); i++) {
+      auto cur_tensor = ms_inputs->at(i);
       MS_ASSERT(cur_tensor != nullptr);
       size_t size;
-      char *bin_buf = ReadFile(flags_->input_data_list_[i].c_str(), &size);
+      std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
+      char *bin_buf = ReadFile(file_name.c_str(), &size);
       if (bin_buf == nullptr) {
         MS_LOG(ERROR) << "ReadFile return nullptr";
         return RET_ERROR;
@@ -158,94 +152,12 @@ int NetTrain::ReadInputFile() {
   return RET_OK;
 }
 
-int NetTrain::CompareOutput() {
-  std::cout << "================ Comparing Output data ================" << std::endl;
-  float total_bias = 0;
-  int total_size = 0;
-  bool has_error = false;
-  auto tensors_list = session_->GetOutputs();
-  if (tensors_list.empty()) {
-    MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
-    return RET_ERROR;
-  }
-  mindspore::tensor::MSTensor *tensor = nullptr;
-  int i = 1;
-  for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
-    tensor = session_->GetOutputByTensorName(it->first);
-    std::cout << "output is tensor " << it->first << "\n";
-    auto outputs = tensor->MutableData();
-    size_t size;
-    std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
-    auto *bin_buf = ReadFileBuf(output_file.c_str(), &size);
-    if (bin_buf == nullptr) {
-      MS_LOG(ERROR) << "ReadFile return nullptr";
-      return RET_ERROR;
-    }
-
-    if (flags_->enable_fp16_ && tensor->data_type() == kNumberTypeFloat16) {
-      if (static_cast<int>(size / sizeof(float)) != tensor->ElementsNum()) {
-        MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
-                      << ", read size: " << size / sizeof(float);
-        return RET_ERROR;
-      }
-    } else {
-      if (size != tensor->Size()) {
-        MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
-                      << ", read size: " << size;
-        return RET_ERROR;
-      }
-    }
-    float bias = 0.f;
-    if (flags_->enable_fp16_ && tensor->data_type() == kNumberTypeFloat16) {
-#ifdef ENABLE_FP16
-      bias = CompareData<float16_t>(bin_buf, tensor->ElementsNum(), reinterpret_cast<float16_t *>(outputs));
-#endif
-    } else {
-      bias = CompareData<float>(bin_buf, tensor->ElementsNum(), reinterpret_cast<float *>(outputs));
-    }
-    if (bias >= 0) {
-      total_bias += bias;
-      total_size++;
-    } else {
-      has_error = true;
-      break;
-    }
-    i++;
-    delete[] bin_buf;
-  }
-
-  if (!has_error) {
-    float mean_bias;
-    if (total_size != 0) {
-      mean_bias = total_bias / total_size * 100;
-    } else {
-      mean_bias = 0;
-    }
-
-    std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
-              << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
-    std::cout << "=======================================================" << std::endl << std::endl;
-
-    if (mean_bias > this->flags_->accuracy_threshold_) {
-      MS_LOG(ERROR) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
-      std::cerr << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
-      return RET_ERROR;
-    } else {
-      return RET_OK;
-    }
-  } else {
-    MS_LOG(ERROR) << "Error in CompareData";
-    std::cerr << "Error in CompareData" << std::endl;
-    std::cout << "=======================================================" << std::endl << std::endl;
-    return RET_ERROR;
-  }
-}
-int NetTrain::CompareOutputLite(const std::unique_ptr<session::LiteSession> &lite_session) {
+int NetTrain::CompareOutput(const session::LiteSession &lite_session) {
   std::cout << "================ Comparing Forward Output data ================" << std::endl;
   float total_bias = 0;
   int total_size = 0;
   bool has_error = false;
-  auto tensors_list = lite_session->GetOutputs();
+  auto tensors_list = lite_session.GetOutputs();
   if (tensors_list.empty()) {
     MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
     return RET_ERROR;
@@ -253,9 +165,9 @@ int NetTrain::CompareOutputLite(const std::unique_ptr<session::LiteSession> &lit
   mindspore::tensor::MSTensor *tensor = nullptr;
   int i = 1;
   for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
-    tensor = lite_session->GetOutputByTensorName(it->first);
+    tensor = lite_session.GetOutputByTensorName(it->first);
     std::cout << "output is tensor " << it->first << "\n";
-    auto outputs = tensor->MutableData();
+    auto outputs = tensor->data();
     size_t size;
     std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
     auto *bin_buf = ReadFileBuf(output_file.c_str(), &size);
@@ -307,7 +219,7 @@ int NetTrain::CompareOutputLite(const std::unique_ptr<session::LiteSession> &lit
   }
 }
 
-int NetTrain::MarkPerformance() {
+int NetTrain::MarkPerformance(session::TrainSession *session) {
   MS_LOG(INFO) << "Running train loops...";
   std::cout << "Running train loops..." << std::endl;
   uint64_t time_min = 0xFFFFFFFFFFFFFFFF;
@@ -315,10 +227,10 @@ int NetTrain::MarkPerformance() {
   uint64_t time_avg = 0;
 
   for (int i = 0; i < flags_->epochs_; i++) {
-    session_->BindThread(true);
+    session->BindThread(true);
     auto start = GetTimeUs();
     auto status =
-      flags_->time_profiling_ ? session_->RunGraph(before_call_back_, after_call_back_) : session_->RunGraph();
+      flags_->time_profiling_ ? session->RunGraph(before_call_back_, after_call_back_) : session->RunGraph();
     if (status != 0) {
       MS_LOG(ERROR) << "Inference error " << status;
       std::cerr << "Inference error " << status;
@@ -330,7 +242,7 @@ int NetTrain::MarkPerformance() {
     time_min = std::min(time_min, time);
     time_max = std::max(time_max, time);
     time_avg += time;
-    session_->BindThread(false);
+    session->BindThread(false);
   }
 
   if (flags_->time_profiling_) {
@@ -352,10 +264,9 @@ int NetTrain::MarkPerformance() {
   return RET_OK;
 }
 
-int NetTrain::MarkAccuracy() {
+int NetTrain::MarkAccuracy(session::LiteSession *session) {
   MS_LOG(INFO) << "MarkAccuracy";
-  std::cout << "MarkAccuracy" << std::endl;
-  for (auto &msInput : ms_inputs_) {
+  for (auto &msInput : session->GetInputs()) {
     switch (msInput->data_type()) {
       case TypeId::kNumberTypeFloat:
         PrintInputData<float>(msInput);
@@ -371,50 +282,14 @@ int NetTrain::MarkAccuracy() {
         return RET_ERROR;
     }
   }
-  session_->Eval();
-
-  auto status = session_->RunGraph(before_call_back_, after_call_back_);
+  auto status = session->RunGraph();
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Inference error " << status;
     std::cerr << "Inference error " << status << std::endl;
     return status;
   }
 
-  status = CompareOutput();
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Compare output error " << status;
-    std::cerr << "Compare output error " << status << std::endl;
-    return status;
-  }
-  return RET_OK;
-}
-int NetTrain::MarkAccuracyLite(const std::unique_ptr<session::LiteSession> &lite_session) {
-  MS_LOG(INFO) << "MarkAccuracy";
-  std::cout << "MarkAccuracy" << std::endl;
-  for (auto &msInput : ms_inputs_) {
-    switch (msInput->data_type()) {
-      case TypeId::kNumberTypeFloat:
-        PrintInputData<float>(msInput);
-        break;
-      case TypeId::kNumberTypeFloat32:
-        PrintInputData<float>(msInput);
-        break;
-      case TypeId::kNumberTypeInt32:
-        PrintInputData<int>(msInput);
-        break;
-      default:
-        MS_LOG(ERROR) << "Datatype " << msInput->data_type() << " is not supported.";
-        return RET_ERROR;
-    }
-  }
-  auto status = lite_session->RunGraph();
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Inference error " << status;
-    std::cerr << "Inference error " << status << std::endl;
-    return status;
-  }
-
-  status = CompareOutputLite(lite_session);
+  status = CompareOutput(*session);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Compare output error " << status;
     std::cerr << "Compare output error " << status << std::endl;
@@ -423,228 +298,106 @@ int NetTrain::MarkAccuracyLite(const std::unique_ptr<session::LiteSession> &lite
   return RET_OK;
 }
 
-int NetTrain::RunExportedNet() {
+static CpuBindMode FlagToBindMode(int flag) {
+  if (flag == 2) {
+    return MID_CPU;
+  }
+  if (flag == 1) {
+    return HIGHER_CPU;
+  }
+  return NO_BIND;
+}
+
+int NetTrain::CreateAndRunNetwork(const std::string &filename, int train_session, int epochs) {
   auto start_prepare_time = GetTimeUs();
-  // Load graph
-  std::string model_name = flags_->export_file_.substr(flags_->export_file_.find_last_of(DELIM_SLASH) + 1);
+  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
+  Context context;
+  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = FlagToBindMode(flags_->cpu_bind_mode_);
+  context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
+  context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
+  context.thread_num_ = flags_->num_threads_;
 
-  MS_LOG(INFO) << "start reading exported model file";
-  std::cout << "start reading exported model file" << std::endl;
-  auto context = std::make_shared<Context>();
-  if (context == nullptr) {
-    MS_LOG(ERROR) << "New context failed while running " << model_name.c_str();
-    std::cerr << "New context failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
-  }
-
-  if (flags_->cpu_bind_mode_ == 2) {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
-  } else if (flags_->cpu_bind_mode_ == 1) {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
-  } else {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
-  }
-
-  context->thread_num_ = flags_->num_threads_;
-
-  auto *model = mindspore::lite::Model::Import(flags_->export_file_.c_str());
+  MS_LOG(INFO) << "start reading model file" << filename.c_str();
+  std::cout << "start reading model file " << filename.c_str() << std::endl;
+  auto *model = mindspore::lite::Model::Import(filename.c_str());
   if (model == nullptr) {
     MS_LOG(ERROR) << "create model for train session failed";
     return RET_ERROR;
   }
 
-  session_ = session::TrainSession::CreateSession(model, context.get());
-  if (session_ == nullptr) {
-    MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
-    std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
-  }
-  if (flags_->loss_name_ != "") {
-    session_->SetLossName(flags_->loss_name_);
-  }
-  ms_inputs_ = session_->GetInputs();
-  auto end_prepare_time = GetTimeUs();
-  MS_LOG(INFO) << "Exported model PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms";
-  std::cout << "Exported model PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms" << std::endl;
-
-  // Load input
-  MS_LOG(INFO) << "start generate input data";
-  auto status = LoadInput();
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Generate input data error";
-    return status;
-  }
-
-  if (!flags_->data_file_.empty()) {
-    MS_LOG(INFO) << "Check accuracy for exported model";
-    std::cout << "Check accuracy for exported model " << std::endl;
-    status = MarkAccuracy();
-    for (auto &data : data_) {
-      data.second->shape.clear();
-      data.second->data.clear();
-      delete data.second;
-    }
-    data_.clear();
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "Run MarkAccuracy on exported model error: " << status;
-      std::cout << "Run MarkAccuracy on exported model error: " << status << std::endl;
-      return status;
-    }
-  }
-  return RET_OK;
-}
-
-int NetTrain::RunExportedNetLite(std::string file_name) {
-  auto start_prepare_time = GetTimeUs();
-  // Load graph
-  std::string model_name = file_name.substr(file_name.find_last_of(DELIM_SLASH) + 1);
-
-  MS_LOG(INFO) << "start reading exported model file";
-  std::cout << "reading " << file_name << std::endl;
-  auto context = std::make_shared<Context>();
-  if (context == nullptr) {
-    MS_LOG(ERROR) << "New context failed while running " << model_name.c_str();
-    std::cerr << "New context failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
-  }
-
-  if (flags_->cpu_bind_mode_ == 2) {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
-  } else if (flags_->cpu_bind_mode_ == 1) {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
-  } else {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
-  }
-
-  context->thread_num_ = flags_->num_threads_;
-
-  auto *model = mindspore::lite::Model::Import(file_name.c_str());
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "create model for lite session failed";
-    return RET_ERROR;
-  }
-  auto lite_session = std::unique_ptr<session::LiteSession>(session::LiteSession::CreateSession(context.get()));
-  if (lite_session == nullptr) {
-    MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
-    std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
-  }
-  if (lite_session->CompileGraph(model) != RET_OK) {
-    MS_LOG(ERROR) << "Cannot compile model";
-    delete model;
-    return RET_ERROR;
-  }
-  ms_inputs_ = lite_session->GetInputs();
-  auto end_prepare_time = GetTimeUs();
-  MS_LOG(INFO) << "Exported model PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms";
-  std::cout << "Exported model PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms" << std::endl;
-
-  // Load input
-  MS_LOG(INFO) << "start generate input data";
-  auto status = LoadInput();
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Generate input data error";
-    delete model;
-    return status;
-  }
-  if (!flags_->data_file_.empty()) {
-    MS_LOG(INFO) << "Check accuracy for exported model";
-    std::cout << "Check accuracy for exported model " << std::endl;
-    status = MarkAccuracyLite(lite_session);
-    for (auto &data : data_) {
-      data.second->shape.clear();
-      data.second->data.clear();
-      delete data.second;
-    }
-    data_.clear();
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "Run MarkAccuracy on exported model error: " << status;
-      std::cout << "Run MarkAccuracy on exported model error: " << status << std::endl;
+  session::LiteSession *session = nullptr;
+  session::TrainSession *t_session = nullptr;
+  if (train_session) {
+    t_session = session::TrainSession::CreateSession(model, &context);
+    if (t_session == nullptr) {
+      MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str();
+      std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl;
       delete model;
-      return status;
+      return RET_ERROR;
     }
-  }
-  delete model;
-  return RET_OK;
-}
 
-int NetTrain::RunNetTrain() {
-  auto start_prepare_time = GetTimeUs();
-  // Load graph
-  std::string model_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
-
-  MS_LOG(INFO) << "start reading model file";
-  std::cout << "start reading model file" << std::endl;
-  auto context = std::make_shared<Context>();
-  if (context == nullptr) {
-    MS_LOG(ERROR) << "New context failed while running " << model_name.c_str();
-    std::cerr << "New context failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
-  }
-
-  if (flags_->cpu_bind_mode_ == 2) {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
-  } else if (flags_->cpu_bind_mode_ == 1) {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
+    if (flags_->loss_name_ != "") {
+      t_session->SetLossName(flags_->loss_name_);
+    }
+    if (epochs > 0) {
+      t_session->Train();
+    }
+    session = t_session;
   } else {
-    context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
-  }
-  context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
-  layer_checksum_ = flags_->layer_checksum_;
-  context->thread_num_ = flags_->num_threads_;
-
-  auto *model = mindspore::lite::Model::Import(flags_->model_file_.c_str());
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "create model for train session failed";
-    return RET_ERROR;
-  }
-  session_ = session::TrainSession::CreateSession(model, context.get());
-  if (session_ == nullptr) {
-    MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str();
-    std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
+    session = session::LiteSession::CreateSession(&context);
+    if (session == nullptr) {
+      MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
+      std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl;
+      delete model;
+      return RET_ERROR;
+    }
+    if (session->CompileGraph(model) != RET_OK) {
+      MS_LOG(ERROR) << "Cannot compile model";
+      delete model;
+      return RET_ERROR;
+    }
+    delete model;
   }
 
-  if (flags_->loss_name_ != "") {
-    session_->SetLossName(flags_->loss_name_);
-  }
-  session_->Train();
-
-  ms_inputs_ = session_->GetInputs();
   auto end_prepare_time = GetTimeUs();
   MS_LOG(INFO) << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms";
   std::cout << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms" << std::endl;
-
   // Load input
-  MS_LOG(INFO) << "start generate input data";
-  auto status = LoadInput();
+  MS_LOG(INFO) << "Load input data";
+  auto ms_inputs = session->GetInputs();
+  auto status = LoadInput(&ms_inputs);
   if (status != RET_OK) {
-    MS_LOG(ERROR) << "Generate input data error";
+    MS_LOG(ERROR) << "Load input data error";
     return status;
   }
-  if (flags_->epochs_ > 0) {
-    status = MarkPerformance();
+
+  if ((epochs > 0) && (t_session != nullptr)) {
+    status = MarkPerformance(t_session);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
       std::cout << "Run MarkPerformance error: " << status << std::endl;
       return status;
     }
+    SaveModels(t_session, model);  // save file if flags are on
   }
   if (!flags_->data_file_.empty()) {
-    status = MarkAccuracy();
-    for (auto &data : data_) {
-      data.second->shape.clear();
-      data.second->data.clear();
-      delete data.second;
+    if (t_session != nullptr) {
+      t_session->Eval();
     }
-    data_.clear();
+    status = MarkAccuracy(session);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
       std::cout << "Run MarkAccuracy error: " << status << std::endl;
       return status;
     }
   }
-  status = CheckExecute(model);
+  return RET_OK;
+}
+
+int NetTrain::RunNetTrain() {
+  CreateAndRunNetwork(flags_->model_file_, true, flags_->epochs_);
+
+  auto status = CheckExecutionOfSavedModels();  // re-initialize sessions according to flags
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Run CheckExecute error: " << status;
     std::cout << "Run CheckExecute error: " << status << std::endl;
@@ -653,8 +406,7 @@ int NetTrain::RunNetTrain() {
   return RET_OK;
 }
 
-int NetTrain::CheckExecute(mindspore::lite::Model *model) {
-  int status;
+int NetTrain::SaveModels(session::TrainSession *session, mindspore::lite::Model *model) {
   if (!flags_->export_file_.empty()) {
     auto ret = Model::Export(model, flags_->export_file_.c_str());
     if (ret != RET_OK) {
@@ -662,67 +414,39 @@ int NetTrain::CheckExecute(mindspore::lite::Model *model) {
       std::cout << "Run SaveToFile error";
       return RET_ERROR;
     }
-    delete session_;
-    session_ = nullptr;
-    status = RunExportedNet();
+  }
+  if (!flags_->inference_file_.empty()) {
+    auto tick = GetTimeUs();
+    auto status = session->ExportInference(flags_->inference_file_);
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "Run Exported model error: " << status;
-      std::cout << "Run Exported model error: " << status << std::endl;
+      MS_LOG(ERROR) << "Save model error: " << status;
+      std::cout << "Save model error: " << status << std::endl;
       return status;
     }
-  } else {
-    if (!flags_->inference_file_.empty()) {
-      auto tick = GetTimeUs();
-      status = session_->ExportInference(flags_->inference_file_);
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "Save model error: " << status;
-        std::cout << "Save model error: " << status << std::endl;
-        return status;
-      }
-      std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n";
-      delete session_;
-      session_ = nullptr;
-
-      status = RunExportedNetLite(flags_->inference_file_ + ".ms");
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "Running saved model error: " << status;
-        std::cout << "Running saved model error: " << status << std::endl;
-        return status;
-      }
-    }
+    std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n";
   }
   return RET_OK;
 }
 
-void NetTrainFlags::InitInputDataList() {
-  char *saveptr1 = nullptr;
-  char *input_list = new char[this->in_data_file_.length() + 1];
-  snprintf(input_list, this->in_data_file_.length() + 1, "%s", this->in_data_file_.c_str());
-  const char *split_c = ",";
-  char *cur_input = strtok_r(input_list, split_c, &saveptr1);
-  while (cur_input != nullptr) {
-    input_data_list_.emplace_back(cur_input);
-    cur_input = strtok_r(nullptr, split_c, &saveptr1);
-  }
-  delete[] input_list;
-}
-
-void NetTrainFlags::InitResizeDimsList() {
-  std::string content;
-  content = this->resize_dims_in_;
-  std::vector<int64_t> shape;
-  auto shape_strs = StringSplit(content, std::string(DELIM_COLON));
-  for (const auto &shape_str : shape_strs) {
-    shape.clear();
-    auto dim_strs = StringSplit(shape_str, std::string(DELIM_COMMA));
-    std::cout << "Resize Dims: ";
-    for (const auto &dim_str : dim_strs) {
-      std::cout << dim_str << " ";
-      shape.emplace_back(static_cast<int64_t>(std::stoi(dim_str)));
+int NetTrain::CheckExecutionOfSavedModels() {
+  int status = RET_OK;
+  if (!flags_->export_file_.empty()) {
+    status = NetTrain::CreateAndRunNetwork(flags_->export_file_, true, 0);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
+      std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
+      return status;
     }
-    std::cout << std::endl;
-    this->resize_dims_.emplace_back(shape);
   }
+  if (!flags_->inference_file_.empty()) {
+    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + ".ms", false, 0);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
+      std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
+      return status;
+    }
+  }
+  return status;
 }
 
 int NetTrain::InitCallbackParameter() {
@@ -766,7 +490,7 @@ int NetTrain::InitCallbackParameter() {
     op_times_by_type_[call_param.node_type].second += cost;
     op_times_by_name_[call_param.node_name].first++;
     op_times_by_name_[call_param.node_name].second += cost;
-    if (layer_checksum_) {
+    if (flags_->layer_checksum_) {
       auto out_tensor = after_outputs.at(0);
       void *output = out_tensor->MutableData();
       int tensor_size = out_tensor->ElementsNum();
@@ -841,13 +565,6 @@ int NetTrain::Init() {
     std::cerr << "modelPath is required" << std::endl;
     return 1;
   }
-  flags_->InitInputDataList();
-  flags_->InitResizeDimsList();
-  if (!flags_->resize_dims_.empty() && flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
-    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
-    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
-    return RET_ERROR;
-  }
 
   if (flags_->time_profiling_) {
     auto status = InitCallbackParameter();
@@ -925,14 +642,6 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
   return RET_OK;
 }
 
-NetTrain::~NetTrain() {
-  for (auto iter : this->data_) {
-    delete (iter.second);
-  }
-  this->data_.clear();
-  if (session_ != nullptr) delete (session_);
-}
-
 int RunNetTrain(int argc, const char **argv) {
   NetTrainFlags flags;
   Option<std::string> err = flags.ParseFlags(argc, argv);
diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h
index 252f3d31be8..cd62d1d817d 100644
--- a/mindspore/lite/tools/benchmark_train/net_train.h
+++ b/mindspore/lite/tools/benchmark_train/net_train.h
@@ -42,15 +42,6 @@ enum MS_API DataType { kImage = 0, kBinary = 1 };
 constexpr float relativeTolerance = 1e-5;
 constexpr float absoluteTolerance = 1e-8;
 
-struct MS_API CheckTensor {
-  CheckTensor(const std::vector<size_t> &shape, const std::vector<float> &data) {
-    this->shape = shape;
-    this->data = data;
-  }
-  std::vector<size_t> shape;
-  std::vector<float> data;
-};
-
 template <typename T>
 float TensorSum(void *data, int size) {
   T *typed_data = reinterpret_cast<T *>(data);
@@ -84,10 +75,6 @@ class MS_API NetTrainFlags : public virtual FlagParser {
 
   ~NetTrainFlags() override = default;
 
-  void InitInputDataList();
-
-  void InitResizeDimsList();
-
  public:
   // common
   std::string model_file_;
@@ -118,25 +105,22 @@ class MS_API NetTrainFlags : public virtual FlagParser {
 class MS_API NetTrain {
  public:
   explicit NetTrain(NetTrainFlags *flags) : flags_(flags) {}
-
-  virtual ~NetTrain();
+  virtual ~NetTrain() = default;
 
   int Init();
   int RunNetTrain();
-  int RunExportedNet();
 
  private:
   // call GenerateInputData or ReadInputFile to init inputTensors
-  int LoadInput();
+  int LoadInput(Vector<tensor::MSTensor *> *ms_inputs);
 
   // call GenerateRandomData to fill inputTensors
-  int GenerateInputData();
+  int GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
 
   int GenerateRandomData(size_t size, void *data);
 
-  int ReadInputFile();
-
-  int CompareOutput();
+  int ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
+  int CreateAndRunNetwork(const std::string &filename, int train_session, int epochs);
 
   int InitCallbackParameter();
 
@@ -208,22 +192,13 @@ class MS_API NetTrain {
     return meanError;
   }
 
-  int MarkPerformance();
+  int MarkPerformance(session::TrainSession *session);
 
-  int MarkAccuracy();
-
- private:
-  int RunExportedNetLite(std::string file_name);
-  int MarkAccuracyLite(const std::unique_ptr<session::LiteSession> &lite_session);
-  int CompareOutputLite(const std::unique_ptr<session::LiteSession> &lite_session);
-  int CheckExecute(mindspore::lite::Model *model);
+  int MarkAccuracy(session::LiteSession *lite_session);
+  int CompareOutput(const session::LiteSession &lite_session);
+  int SaveModels(session::TrainSession *session, mindspore::lite::Model *model);
+  int CheckExecutionOfSavedModels();
   NetTrainFlags *flags_;
-  session::TrainSession *session_ = nullptr;
-  std::vector<mindspore::tensor::MSTensor *> ms_inputs_;
-  std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> ms_outputs_;
-  std::unordered_map<std::string, CheckTensor *> data_;
-  std::unordered_map<std::string, TypeId> data_type_map_{{"FLOAT", TypeId::kNumberTypeFloat},
-                                                         {"INT32", TypeId::kNumberTypeInt32}};
 
   // callback parameters
   uint64_t op_begin_ = 0;
@@ -234,7 +209,6 @@ class MS_API NetTrain {
 
   mindspore::KernelCallBack before_call_back_;
   mindspore::KernelCallBack after_call_back_;
-  bool layer_checksum_ = false;
 };
 
 int MS_API RunNetTrain(int argc, const char **argv);
diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h
index 6c412c84f39..bf05ad0e58a 100644
--- a/mindspore/lite/tools/common/graph_util.h
+++ b/mindspore/lite/tools/common/graph_util.h
@@ -143,7 +143,7 @@ bool IndexingCompress(const std::set<T> &quant_data_set, const std::map<T, size_
   tensor->data.resize(new_data_str.size());
 
   tensor->weightQunatCompressType = schema::WeightQunatCompressType_INDEXING;
-  MS_LOG(ERROR) << "set WeightQunatCompressType_INDEXING";
+  MS_LOG(DEBUG) << "set WeightQunatCompressType_INDEXING";
   return true;
 }
 
@@ -285,21 +285,21 @@ bool PackRepetition(size_t bit_num, schema::TensorT *tensor) {
   auto pack_sparsity_size_in_bit =
     1 * 8 + 4 * 8 + bit_num + bit_num * unique_value_cnt + unique_value_bit * nz_cnt + nz_cnt * coor_best_bit;
   size_t pack_sparsity_size_in_byte = ceil(pack_sparsity_size_in_bit / 8.0);
-  MS_LOG(ERROR) << "coor_best_bit: " << coor_best_bit << " ori: " << origin_size_in_byte
+  MS_LOG(DEBUG) << "coor_best_bit: " << coor_best_bit << " ori: " << origin_size_in_byte
                 << " indexing: " << pack_repetition_size_in_byte << " sparse: " << pack_sparsity_size_in_byte;
   auto min_byte_need = std::min({origin_size_in_byte, pack_repetition_size_in_byte, pack_sparsity_size_in_byte});
   if (min_byte_need == origin_size_in_byte) {
     return false;
   } else if (min_byte_need == pack_repetition_size_in_byte) {
-    MS_LOG(ERROR) << "from " << origin_size_in_byte << " to " << pack_repetition_size_in_byte;
+    MS_LOG(DEBUG) << "from " << origin_size_in_byte << " to " << pack_repetition_size_in_byte;
     return IndexingCompress<T>(quant_data_set, unique_value_index_map, unique_value_bit, unique_value_cnt,
                                pack_repetition_size_in_byte, bit_num, tensor);
   } else if (min_byte_need == pack_sparsity_size_in_byte) {
-    MS_LOG(ERROR) << "from " << origin_size_in_byte << " to " << pack_sparsity_size_in_byte;
+    MS_LOG(DEBUG) << "from " << origin_size_in_byte << " to " << pack_sparsity_size_in_byte;
     return SparsityCompress<T>(quant_data_set, unique_value_index_map, unique_value_bit, unique_value_cnt,
                                pack_sparsity_size_in_byte, nz_cnt, coor_best_bit, bit_num, tensor);
   } else {
-    MS_LOG(ERROR) << "unexpected: " << min_byte_need << " not in {" << origin_size_in_byte << " "
+    MS_LOG(DEBUG) << "unexpected: " << min_byte_need << " not in {" << origin_size_in_byte << " "
                   << pack_repetition_size_in_byte << " " << pack_sparsity_size_in_byte << "}";
   }
   return false;
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 1542c493a56..73d2c6e71fb 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -22,6 +22,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
index 73c70b66d0b..8819d7d786d 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
@@ -14,14 +14,16 @@
  * limitations under the License.
  */
 
+#include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
 #include <vector>
 #include <cmath>
-#include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h"
+#include <algorithm>
 #include "tools/converter/converter_context.h"
 #include "tools/converter/quantizer/quantize_util.h"
 #include "tools/common/tensor_util.h"
 #include "tools/common/graph_util.h"
 #include "tools/common/node_util.h"
+#include "src/common/quant_utils.h"
 
 namespace mindspore::lite {
 namespace {
@@ -49,7 +51,7 @@ STATUS ComputeDataToInt8(const std::unique_ptr<TensorT> &tensor, int32_t index)
       return RET_OK;
     }
     for (size_t j = 0; j < wShapeSize; j++) {
-      qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
+      qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
     }
   } else {  // convert uint8 to int8
     auto *weightData = static_cast<uint8_t *>(oriWeightData);
@@ -141,7 +143,7 @@ STATUS ComputeQuantTensorPerChannel(TensorT *tensor, const int &tensor_index, co
         auto *dst_data_int32 = reinterpret_cast<int32_t *>(dst_data.data());
         dst_data_int32[index] = quant_data;
       } else {
-        auto quant_data = quant::QuantizeData<int8_t>(raw_data, tensor->quantParams.at(i).get());
+        auto quant_data = QuantizeData<int8_t>(raw_data, tensor->quantParams.at(i).get());
         dst_data[index] = quant_data;
       }
     }
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
index e01b4926480..55ef39e4439 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
@@ -44,6 +44,7 @@
 #include "securec/include/securec.h"
 #include "tools/common/tensor_util.h"
 #include "src/common/file_utils.h"
+#include "src/common/quant_utils.h"
 #include "src/common/utils.h"
 #include "tools/converter/quantizer/weight_quantizer.h"
 
@@ -1282,8 +1283,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     return status;
   }
 
-  if (calibrator_->config_param_.mixed) {
-    // get opname_bit map
+  if (calibrator_->config_param_.mixed) {  // get opname_bit map
     auto weight_quant_func_graph = CopyFuncGraph(func_graph);
     if (weight_quant_func_graph == nullptr) {
       MS_LOG(ERROR) << "CopyFuncGraph error";
@@ -1315,7 +1315,6 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     MS_LOG(ERROR) << "create session failed!";
     return RET_ERROR;
   }
-
   MS_LOG(INFO) << "start to update divergence's max value";
   status = DoInference();
   if (status != RET_OK) {
@@ -1363,14 +1362,12 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
       MS_LOG(ERROR) << "create session failed!";
       return RET_ERROR;
     }
-
     MS_LOG(INFO) << "do bias correction";
     status = BiasCorrection(func_graph);
     if (status != RET_OK) {
       MS_LOG(WARNING) << "BiasCorrection failed.";
     }
   }
-
   return RET_OK;
 }
 
@@ -1477,7 +1474,7 @@ KernelCallBack PostTrainingQuantizer::GetBeforeCallBack(bool int8_op) {
         quant_param_t.scale = quant_params[0].scale;
         quant_param_t.zeroPoint = quant_params[0].zeroPoint;
         for (auto float_data : fp32_op_input) {
-          auto quant_data = QuantizeData<int8_t>(float_data, quant_param_t, quant_max, quant_min);
+          auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, quant_max, quant_min);
           quant_datas.push_back(quant_data);
         }
 
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 2d0be9f0586..74f6f46408e 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -100,12 +100,12 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
   return true;
 }
 
-bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
+bool QuantStrategy::CanOpPostQuantized(const AnfNodePtr &node) const {
   MS_ASSERT(node != nullptr);
   if (!node->isa<mindspore::CNode>()) {
     return false;
   }
-  auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
+  const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
   auto type = NodePrimitiveType(cnode);
   static const std::vector<std::string> int8OpList = {
     ops::kNameAddFusion,     ops::kNameActivation,      ops::kNameAvgPoolFusion,
@@ -268,67 +268,6 @@ bool TensorQuantParamsInited(const schema::TensorT &tensor) {
   return true;
 }
 
-STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
-                             int quant_min, int num_bits) {
-  MS_ASSERT(quantParam != nullptr);
-  if (mMin > 0.0f) {
-    MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
-    mMin = 0.0f;
-  }
-  if (mMax < 0.0f) {
-    MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
-    mMax = 0.0f;
-  }
-  if (mMin > mMax) {
-    MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
-    return RET_PARAM_INVALID;
-  }
-  if (mMin == mMax) {
-    if (mMin != 0.0f) {
-      MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
-      return RET_ERROR;
-    }
-    quantParam->inited = true;
-    quantParam->min = mMin;
-    quantParam->max = mMax;
-    quantParam->scale = 0.0f;
-    quantParam->zeroPoint = 0;
-    quantParam->narrowRange = narrowRange;
-    quantParam->numBits = num_bits;
-    return RET_OK;
-  }
-
-  auto quantMinFloat = static_cast<double>(quant_min);
-  auto quantMaxFloat = static_cast<double>(quant_max);
-  if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {
-    MS_LOG(ERROR) << "divisor cannot be 0";
-    return RET_ERROR;
-  }
-  double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
-  if (fabs(scale) <= 0.0f) {
-    MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
-    return RET_ERROR;
-  }
-  const double zeroPointFromMin = quantMinFloat - mMin / scale;
-  int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));
-  if (scale < SCALE_THREASHOLD) {
-    zeroPoint = 0;
-  }
-  // The zero point should always be in the range of quantized value,
-  // [qmin, qmax].
-  MS_ASSERT(zeroPoint >= quantMin);
-  MS_ASSERT(zeroPoint <= quantMax);
-  quantParam->inited = true;
-  quantParam->min = mMin;
-  quantParam->max = mMax;
-  quantParam->scale = scale;
-  quantParam->zeroPoint = zeroPoint;
-  quantParam->narrowRange = narrowRange;
-  quantParam->numBits = num_bits;
-
-  return RET_OK;
-}
-
 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int numBits) {
   MS_ASSERT(quantParam != nullptr);
   if (mMin > 0.0f) {
@@ -999,26 +938,6 @@ STATUS UpdateTensorDataAndSize(const tensor::TensorPtr &weight, void *quant_data
   return RET_OK;
 }
 
-void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
-                         bool channel_at_first, float *desired_max, float *desired_min) {
-  float min = FLT_MAX;
-  float max = -FLT_MAX;
-  // find min and max
-  for (int j = 0; j < one_filter_size; j++) {
-    auto index = j + i * one_filter_size;
-    if (!channel_at_first) {
-      index = j * channels + i;
-    }
-    if (index >= elem_count) {
-      MS_LOG(ERROR) << "over flow!";
-    }
-    min = std::min(min, raw_datas[index]);
-    max = std::max(max, raw_datas[index]);
-  }
-  *desired_max = max;
-  *desired_min = min;
-}
-
 int CalChannels(const ShapeVector &dims, int channel_cnt, bool *channel_at_first) {
   auto channels = dims[0];
   if (!(*channel_at_first)) {
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index 6e0c04d3548..299324b9ce0 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -43,6 +43,7 @@
 #include "src/lite_session.h"
 #include "tools/converter/graphdef_transform.h"
 #include "src/common/file_utils.h"
+#include "src/common/quant_utils.h"
 
 namespace mindspore::lite::quant {
 static constexpr size_t UINT8_QUANTIZATION = 8;
@@ -82,7 +83,7 @@ class QuantStrategy {
 
   bool CanConvOpQuantized(const CNodePtr &node) const;
   bool CanMulOpQuantized(const CNodePtr &node) const;
-  bool CanOpPostQuantized(AnfNodePtr &node) const;
+  bool CanOpPostQuantized(const AnfNodePtr &node) const;
   bool CanTensorQuantized(const AnfNodePtr &inputNode) const;
 
   size_t m_weight_size_;
@@ -100,9 +101,6 @@ constexpr int quant_param_size = 32 * 8;
 
 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive);
 
-STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
-                             int quant_min, int num_bits);
-
 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange = false,
                              int numBits = UINT8_QUANTIZATION);
 
@@ -112,9 +110,6 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
 
 STATUS UpdateTensorDataAndSize(const tensor::TensorPtr &weight, void *quant_datas, int new_size, TypeId new_data_type);
 
-void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
-                         bool channel_at_first, float *desired_max, float *desired_min);
-
 int CalChannels(const ShapeVector &dims, int channel_cnt, bool *channel_at_first);
 
 void CalQuantAssitInfo(const PrimitivePtr &primitive, const ShapeVector &shapes, int index, bool *channel_at_first,
@@ -123,193 +118,10 @@ void CalQuantAssitInfo(const PrimitivePtr &primitive, const ShapeVector &shapes,
 void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<int> &shapes, int index,
                        bool *channel_at_first, int *channel_cnt);
 
-template <typename T>
-T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
-  MS_ASSERT(quantParam != nullptr);
-  MS_ASSERT(quantParam->inited);
-  const auto scale = quantParam->scale;
-  const auto zeroPoint = quantParam->zeroPoint;
-  const auto numBit = quantParam->numBits;
-  const auto narrowRange = quantParam->narrowRange;
-  double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1);
-  const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<T>::min()) * scale;
-  double minLimit;
-  if (narrowRange) {
-    minLimit = static_cast<float>(std::numeric_limits<T>::min() + 1 - zeroPoint) * scale;
-  } else {
-    minLimit = static_cast<float>(std::numeric_limits<T>::min() - zeroPoint) * scale;
-  }
-
-  return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
-    double tmp;
-    if (originData > maxLimit) {
-      tmp = maxLimit;
-    } else if (originData < minLimit) {
-      tmp = minLimit;
-    } else {
-      tmp = originData;
-    }
-    auto quantData = static_cast<T>(std::round(zeroPoint + tmp / scale));
-    return quantData;
-  }();
-}
-
-template <typename T>
-T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quant_max, int quant_min) {
-  MS_ASSERT(quantParam != nullptr);
-  MS_ASSERT(quantParam->inited);
-  const auto scale = quantParam.scale;
-  const int zeroPoint = quantParam.zeroPoint;
-  const auto narrowRange = quantParam.narrowRange;
-  const int maxLimit = quant_max;
-  const int minLimit = quant_min;
-  if (scale <= SCALE_THREASHOLD) {
-    return 0;
-  }
-  return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
-    auto quant_data = std::round(originData / scale + zeroPoint);
-    if (quant_data > maxLimit) {
-      quant_data = maxLimit;
-    } else if (quant_data < minLimit) {
-      quant_data = minLimit;
-    }
-    return static_cast<T>(quant_data);
-  }();
-}
-
 bool QuantParamEqual(const schema::QuantParamT &quant_param1, const schema::QuantParamT &quant_param2);
 
 bool TensorQuantParamsInited(const schema::TensorT &tensor);
 
-template <typename T>
-STATUS DoPerChannelQuant(const tensor::TensorPtr &weight, const QuantType &quant_type,
-                         std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
-                         const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas,
-                         std::vector<float> *dequant_datas, TypeId quant_data_type, bool channel_at_first = true,
-                         int channel_cnt = -1) {
-  auto dims = weight->shape();
-  size_t elem_count = weight->DataSize();
-  auto *raw_datas = static_cast<float *>(weight->data_c());
-  auto channels = CalChannels(dims, channel_cnt, &channel_at_first);
-  if (channels == 0) {
-    MS_LOG(ERROR) << "channels is zero";
-    return RET_ERROR;
-  }
-  size_t one_filter_size = elem_count / channels;
-  bool do_quant = quant_param_size / (sizeof(float) * 8 - bit_num) < one_filter_size;
-  if (!do_quant && quant_type == QuantType_WeightQuant) {
-    MS_LOG(INFO) << "too few elements in a filter, no need to quantize. " << one_filter_size;
-    return RET_CONTINUE;
-  }
-  for (int i = 0; i < channels; i++) {
-    float min = FLT_MAX;
-    float max = -FLT_MAX;
-    GetMaxMinPerchannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min);
-    schema::QuantParamT quant_param;
-    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-      return status;
-    }
-    // do quantization
-    double average_dequant = 0;
-    double average_raw = 0;
-    for (uint32_t j = 0; j < one_filter_size; j++) {
-      auto index = j + i * one_filter_size;
-      if (!channel_at_first) {
-        index = j * channels + i;
-      }
-      MS_ASSERT(index < elem_count);
-      float raw_data = raw_datas[index];
-      auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
-      (*quant_datas)[index] = quant_data;
-
-      if (quant_type == QuantType_WeightQuant) {
-        float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint);
-        (*dequant_datas)[index] = dequant_data;
-        average_dequant += dequant_data;
-        average_raw += raw_data;
-      }
-    }
-    if (quant_type == QuantType_WeightQuant && !k_means) {
-      // mean
-      average_dequant = average_dequant / one_filter_size;
-      average_raw = average_raw / one_filter_size;
-      // std
-      double variance_dequant = 0;
-      double variance_raw = 0;
-      for (uint32_t j = 0; j < one_filter_size; j++) {
-        auto index = j + i * one_filter_size;
-        if (!channel_at_first) {
-          index = j * channels + i;
-        }
-        MS_ASSERT(index < elem_count);
-        variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2);
-        variance_raw += std::pow(raw_datas[index] - average_raw, 2);
-      }
-      variance_dequant = std::sqrt(variance_dequant / one_filter_size);
-      variance_raw = std::sqrt(variance_raw / one_filter_size);
-      quant_param.varCorr = 1;
-      if (variance_raw != 0 && variance_dequant != 0) {
-        auto temp_var_corr = variance_raw / variance_dequant;
-        if (temp_var_corr > 0 && temp_var_corr < 10) {
-          quant_param.varCorr = temp_var_corr;
-        } else {
-          MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
-        }
-      }
-      quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr;
-    }
-    quant_params->emplace_back(quant_param);
-  }
-  auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T), quant_data_type);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
-template <typename T>
-STATUS DoPerLayerQuant(const tensor::TensorPtr &weight, const QuantType &quant_type,
-                       std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
-                       const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas,
-                       TypeId quant_data_type) {
-  auto dims = weight->shape();
-  size_t elem_count = weight->DataSize();
-  auto *raw_datas = static_cast<float *>(weight->data_c());
-  float min = FLT_MAX;
-  float max = -FLT_MIN;
-  for (uint32_t i = 0; i < elem_count; i++) {
-    // find max min
-    min = std::min(min, raw_datas[i]);
-    max = std::max(max, raw_datas[i]);
-  }
-
-  schema::QuantParamT quant_param;
-  if (!k_means) {
-    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-      return status;
-    }
-  }
-  quant_params->emplace_back(quant_param);
-  // update data and datatype
-  for (uint32_t i = 0; i < elem_count; i++) {
-    float raw_data = raw_datas[i];
-    if (!k_means) {
-      auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
-      (*quant_datas)[i] = quant_data;
-    }
-  }
-  auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T), quant_data_type);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
 template <typename T>
 STATUS DoBitPack(const tensor::TensorPtr &weight, const size_t &bit_num, const std::vector<T> &quant_datas) {
   if (bit_num != 8 && bit_num != 16) {
@@ -363,15 +175,19 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
   }
 
   std::vector<T> quant_data(elem_count);
-  std::vector<float> dequant_datas(elem_count);
   int ret = RET_OK;
   if (per_channel) {
     bool channel_at_first = true;
     int channel_cnt = -1;
     CalQuantAssitInfo(primitive, dims, index, &channel_at_first, &channel_cnt);
-    // channel at first
-    ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data,
-                               &dequant_datas, quant_data_type, channel_at_first, channel_cnt);
+    auto channels = CalChannels(dims, channel_cnt, &channel_at_first);
+    if (channels == 0) {
+      MS_LOG(ERROR) << "channels is zero";
+      return RET_ERROR;
+    }
+    ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(),
+                               static_cast<mindspore::schema::QuantType>(quant_type), &quant_params, quant_max,
+                               quant_min, bit_num, k_means, &quant_data, channels, channel_at_first);
     if (ret == RET_CONTINUE) {
       return ret;
     } else if (ret != RET_OK) {
@@ -379,13 +195,18 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
       return ret;
     }
   } else {
-    ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data,
-                             quant_data_type);
+    ret = DoPerLayerQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
+                             quant_min, bit_num, k_means, &quant_data);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Do per layer quant failed.";
       return ret;
     }
   }
+  auto status = UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(T), quant_data_type);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
+    return RET_ERROR;
+  }
 
 #ifdef HUFFMAN_ENCODE
   auto huffman_encode = std::make_unique<lite::HuffmanEncode>();