ToD: add more ops, improve kernels, and support CI with cloud
parent c9381e0448
commit 609fd478ef
@@ -71,7 +71,7 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float
 void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
                 const float *restrict invar, const float *restrict scale, int size, int total_size, int ch,
                 const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) {
-  float N = (float)total_size;
+  const float N = (float)total_size;
   for (int i = 0; i < size; i++) {
     for (int c = 0; c < ch; c++) {
       // dx_2

@@ -64,7 +64,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, Poolin
 #ifdef ENABLE_ARM
             float *out_vec = out + (xw + in_w * xh) * channel + ic;
             float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic);
-            float32x4_t outs = vaddq_s32(outr, delta);
+            float32x4_t outs = vaddq_f32(outr, delta);
             vst1q_f32(out_vec, outs);
 #else

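The one-character change above matters because outr and delta are float32x4_t values; vaddq_s32 would operate on the lanes as 32-bit integers and corrupt the accumulated gradient. A minimal, self-contained sketch of the intended float accumulate (hypothetical helper, ARM-only; delta is assumed to be a broadcast scalar):

#include <arm_neon.h>

// Accumulate one pooled-gradient value into four consecutive float output
// channels, mirroring the vld1q_f32 / vaddq_f32 / vst1q_f32 sequence above.
void AccumulateDelta4(float *out_vec, float delta_value) {
  float32x4_t outr = vld1q_f32(out_vec);         // load 4 output floats
  float32x4_t delta = vdupq_n_f32(delta_value);  // broadcast the scalar delta
  float32x4_t outs = vaddq_f32(outr, delta);     // float add (not vaddq_s32)
  vst1q_f32(out_vec, outs);                      // store the 4 updated floats
}
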
@@ -94,7 +94,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, Poolin
 #ifdef ENABLE_ARM
 static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) {
   uint32x4_t res = vcgtq_f32(in, *max);
-  uint32x4_t m_index = vbslq_f32(res, index, prev_index);
+  int32x4_t m_index = vbslq_s32(res, index, prev_index);
   *max = vbslq_f32(res, in, *max);
   return m_index;
 }

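MaxIndex tracks, per SIMD lane, the running maximum value and the position where it occurred. The mask from vcgtq_f32 drives two selects: vbslq_f32 for the float values and, after this fix, vbslq_s32 for the integer indices, so the result type matches int32x4_t. A hypothetical, self-contained sketch of the same idiom (ARM-only):

#include <arm_neon.h>

// Lane-wise running argmax over `steps` groups of four floats. Each lane keeps
// the step index at which its maximum was seen.
void RunningArgmax4(const float *data, int steps, float32x4_t *max, int32x4_t *max_idx) {
  for (int k = 0; k < steps; ++k) {
    float32x4_t in = vld1q_f32(data + 4 * k);
    int32x4_t idx = vdupq_n_s32(k);
    uint32x4_t gt = vcgtq_f32(in, *max);      // lanes where the new value wins
    *max_idx = vbslq_s32(gt, idx, *max_idx);  // integer select for the index lanes
    *max = vbslq_f32(gt, in, *max);           // float select for the value lanes
  }
}
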
@@ -127,7 +127,7 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p
       int kw_s = MSMAX(0, over_w);
       int kw_e = MSMIN(win_w, in_w + over_w);
       int ic = 0;
-      for (; ic < channel - 4; ic += 4) {
+      for (; ic < (channel & ~3); ic += 4) {
        int idx = (yw + yh * output_w) * channel + ic;
 #ifdef ENABLE_ARM
        uint32x4_t max_idx = vdupq_n_u32(0);

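(channel & ~3) rounds the channel count down to a multiple of four, so every complete group of four channels takes the vector path and only the remainder is left for the scalar tail; with the old bound channel - 4, the last full group (and the whole loop when channel == 4) fell through to the scalar code. A small, hypothetical illustration of the split:

#include <cstdio>

// Print which channels would go through the 4-wide vector body and which
// through the scalar tail, for a given channel count.
void SplitChannels(int channel) {
  int ic = 0;
  for (; ic < (channel & ~3); ic += 4) {   // full groups of four
    printf("vector lanes %d..%d\n", ic, ic + 3);
  }
  for (; ic < channel; ++ic) {             // remainder, handled one at a time
    printf("scalar lane %d\n", ic);
  }
}
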
@@ -170,9 +170,8 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p
        float delta = dyPtr[idx];
        for (int kh = kh_s; kh < kh_e; kh++) {
          int xh = yh * stride_h + kh - pad_h;
-         int loop = kw_e - kw_s;
-         for (int kw = 0; kw < loop; kw++) {
-           int xw = yw * stride_w + kw + kw_s - pad_w;
+         for (int kw = kw_s; kw < kw_e; kw++) {
+           int xw = yw * stride_w + kw - pad_w;
            int val_idx = (xw + in_w * xh) * channel + ic;
            float val = inPtr[val_idx];
            if (val > max_val) {

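The rewritten loop walks the clipped kernel range [kw_s, kw_e) directly instead of a zero-based loop plus offset, and computes the input column as yw * stride_w + kw - pad_w. A hypothetical, self-contained sketch of why the clipping keeps xw in bounds, assuming over_w = pad_w - yw * stride_w (which is consistent with the kw_s / kw_e definitions above):

#include <algorithm>
#include <cstdio>

// Visit the kernel columns whose input column xw = yw * stride_w + kw - pad_w
// stays inside [0, in_w); with kw_s/kw_e computed this way, no per-iteration
// bounds check is needed inside the body.
void VisitWindowColumns(int yw, int stride_w, int pad_w, int win_w, int in_w) {
  int over_w = pad_w - yw * stride_w;            // assumed definition, see note above
  int kw_s = std::max(0, over_w);
  int kw_e = std::min(win_w, in_w + over_w);
  for (int kw = kw_s; kw < kw_e; kw++) {
    int xw = yw * stride_w + kw - pad_w;         // guaranteed 0 <= xw < in_w
    printf("kw=%d -> xw=%d\n", kw, xw);
  }
}
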
@@ -21,7 +21,31 @@
 namespace mindspore {
 namespace lite {
-#ifndef PRIMITIVE_WRITEABLE
+#ifdef PRIMITIVE_WRITEABLE
+int Abs::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Abs;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Abs) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::AbsT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+#else
 int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
   MS_ASSERT(nullptr != fbb);

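The UnPackAttr body added here is repeated nearly verbatim for Cos, Reciprocal, and Sin below: lazily allocate the primitive wrapper, stamp its type tag on first use, reject a mismatched tag, then lazily allocate the op-specific attribute table. A hypothetical, self-contained distillation of that flow (the struct names are stand-ins, not the real schema types):

#include <new>

// Stand-in types for the sketch only; the real code uses schema::PrimitiveT and
// the generated per-op attribute structs (schema::AbsT, schema::CosT, ...).
struct FakeAttr {};
struct FakePrimitive {
  int type = 0;
  FakeAttr *value = nullptr;
};

// Mirrors the control flow of the UnPackAttr methods above: 0 on success, -1 on error.
int UnpackWithTag(FakePrimitive *&prim, int expected_type) {
  if (prim == nullptr) {
    prim = new (std::nothrow) FakePrimitive;
    if (prim == nullptr) return -1;              // allocation failed
    prim->type = expected_type;                  // first use stamps the type tag
  }
  if (prim->type != expected_type) return -1;    // wrapper already holds a different op
  if (prim->value == nullptr) {
    prim->value = new (std::nothrow) FakeAttr;
    if (prim->value == nullptr) return -1;
  }
  return 0;
}
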
@@ -20,8 +20,8 @@
 #include "src/ops/arithmetic_self.h"

-#ifndef LITE_MINDSPORE_LITE_C_OPS_ABS_H_
-#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_ABS_H_
+#define MINDSPORE_LITE_SRC_OPS_ABS_H_

 namespace mindspore {
 namespace lite {

@@ -32,10 +32,11 @@ class Abs : public ArithmeticSelf {
 #ifdef PRIMITIVE_WRITEABLE
   MS_DECLARE_PARENT(Abs, ArithmeticSelf);
   explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
 };
 }  // namespace lite
 }  // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_ABS_H_
+#endif  // MINDSPORE_LITE_SRC_OPS_ABS_H_

@@ -22,7 +22,31 @@
 namespace mindspore {
 namespace lite {
-#ifndef PRIMITIVE_WRITEABLE
+#ifdef PRIMITIVE_WRITEABLE
+int Cos::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Cos;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Cos) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::CosT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+#else
 int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
   MS_ASSERT(nullptr != fbb);

@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef LITE_MINDSPORE_LITE_C_OPS_COS_H_
-#define LITE_MINDSPORE_LITE_C_OPS_COS_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_COS_H_
+#define MINDSPORE_LITE_SRC_OPS_COS_H_

 #include <vector>
 #include <set>

@@ -30,6 +30,7 @@ class Cos : public ArithmeticSelf {
   ~Cos() = default;
 #ifdef PRIMITIVE_WRITEABLE
   explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
@@ -37,4 +38,4 @@ class Cos : public ArithmeticSelf {
 }  // namespace lite
 }  // namespace mindspore

-#endif  // LITE_MINDSPORE_LITE_C_OPS_COS_H_
+#endif  // MINDSPORE_LITE_SRC_OPS_COS_H_

@@ -529,6 +529,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
   const auto &op_type = prim.name();
   if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") {
     return NewPrimitiveC<Activation>(prim, inputs, quantType);
+  } else if (op_type == "Abs") {
+    return NewPrimitiveC<Abs>(prim, inputs, quantType);
   } else if (op_type == "AddN") {
     return NewPrimitiveC<AddN>(prim, inputs, quantType);
   } else if (op_type == "BatchNorm") {
@@ -539,6 +541,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<Concat>(prim, inputs, quantType);
   } else if (op_type == "Conv2D") {
     return NewPrimitiveC<Conv2D>(prim, inputs, quantType);
+  } else if (op_type == "Cos") {
+    return NewPrimitiveC<Cos>(prim, inputs, quantType);
   } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") {
     return NewPrimitiveC<DepthwiseConv2D>(prim, inputs, quantType);
   } else if (op_type == "Dequant") {
@@ -559,6 +563,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<Quant>(prim, inputs, quantType);
   } else if (op_type == "RealDiv") {
     return NewPrimitiveC<RealDiv>(prim, inputs, quantType);
+  } else if (op_type == "Reciprocal") {
+    return NewPrimitiveC<Reciprocal>(prim, inputs, quantType);
   } else if (op_type == "ReduceMax") {
     return NewPrimitiveC<Reduce>(prim, inputs, quantType);
   } else if (op_type == "ReduceMean") {
@@ -573,6 +579,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<Reduce>(prim, inputs, quantType);
   } else if (op_type == "Reshape") {
     return NewPrimitiveC<Reshape>(prim, inputs, quantType);
+  } else if (op_type == "Sin") {
+    return NewPrimitiveC<Sin>(prim, inputs, quantType);
   } else if (op_type == "Slice") {
     return NewPrimitiveC<Slice>(prim, inputs, quantType);
   } else if (op_type == "Squeeze") {

@@ -22,7 +22,31 @@
 namespace mindspore {
 namespace lite {
-#ifndef PRIMITIVE_WRITEABLE
+#ifdef PRIMITIVE_WRITEABLE
+int Reciprocal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Reciprocal;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Reciprocal) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::ReciprocalT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+#else
 PrimitiveC *ReciprocalCreator(const schema::Primitive *primitive) {
   return PrimitiveC::NewPrimitiveC<Reciprocal>(primitive);
 }

@@ -14,10 +14,13 @@
  * limitations under the License.
  */

-#ifndef LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
-#define LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_
+#define MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_

 #include "src/ops/arithmetic_self.h"
+#ifdef PRIMITIVE_WRITEABLE
+#include <vector>
+#endif

 namespace mindspore {
 namespace lite {

|
|||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Reciprocal, ArithmeticSelf);
|
||||
explicit Reciprocal(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
|
@ -42,4 +46,4 @@ class Reciprocal : public ArithmeticSelf {
|
|||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
|
||||
#endif // MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_
|
||||
|
|
|
@@ -23,6 +23,29 @@
 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
+int Sin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Sin;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Sin) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::SinT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
 #else
 int Sin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);

@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef LITE_MINDSPORE_LITE_C_OPS_SIN_H_
-#define LITE_MINDSPORE_LITE_C_OPS_SIN_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_SIN_H_
+#define MINDSPORE_LITE_SRC_OPS_SIN_H_

 #include <vector>
 #include <set>

@@ -32,6 +32,7 @@ class Sin : public ArithmeticSelf {
 #ifdef PRIMITIVE_WRITEABLE
   MS_DECLARE_PARENT(Sin, ArithmeticSelf);
   explicit Sin(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
@@ -39,4 +40,4 @@ class Sin : public ArithmeticSelf {
 }  // namespace lite
 }  // namespace mindspore

-#endif  // LITE_MINDSPORE_LITE_C_OPS_SIN_H_
+#endif  // MINDSPORE_LITE_SRC_OPS_SIN_H_

@@ -24,6 +24,7 @@ using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_ExpandDims;
 using mindspore::schema::PrimitiveType_Flatten;
+using mindspore::schema::PrimitiveType_FlattenGrad;
 using mindspore::schema::PrimitiveType_Reshape;
 using mindspore::schema::PrimitiveType_Squeeze;
 using mindspore::schema::PrimitiveType_Unsqueeze;
@@ -77,6 +78,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reshape, LiteKernelCreator<Re
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reshape, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Flatten, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Flatten, LiteKernelCreator<ReshapeBaseCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FlattenGrad, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)

@@ -43,7 +43,7 @@ int ArithmeticSelfGradCPUKernel::Init() {
       self_grad_operation_ = ElementDiv;
       break;
     default:
-      MS_LOG(ERROR) << "Unsupport type: " << type;
+      MS_LOG(ERROR) << "Unsupported type: " << type;
       return RET_ERROR;
   }
   return RET_OK;

@@ -360,14 +360,14 @@ session::TrainSession *session::TrainSession::CreateSession(const char *model_bu
   }
   auto ret = session->Init(context);
   if (ret != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "init sesssion failed";
+    MS_LOG(ERROR) << "init session failed";
     delete session;
     return nullptr;
   }

   ret = session->CompileTrainGraph(model);
   if (ret != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "Compiling Train Graph sesssion failed";
+    MS_LOG(ERROR) << "Compiling Train Graph session failed";
     delete session;
     return nullptr;
   }

@@ -1,7 +1,7 @@
 mini_alexnet
 mobilenetv1
 mobilenetv2
-mobilenetv3
+#mobilenetv3
 lenet
 effnet
 effnet_tune

@@ -71,6 +71,7 @@ function Run_Converter() {
 # Run on x86 platform:
 function Run_x86() {
     # Run mindspore converted train models:
+    fail=0
     while read line; do
         model_name=${line}
         if [[ $model_name == \#* ]]; then

@@ -80,21 +81,23 @@ function Run_x86() {
         echo ${model_name}'_train' >> "${run_x86_log_file}"
         echo 'cd '${x86_path}'/mindspore-lite-'${version}'-train-linux-x64' >> "${run_x86_log_file}"
         cd ${x86_path}/mindspore-lite-${version}-train-linux-x64 || return 1
-        echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_outputs.bin --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}"
+        echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_output --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}"
         echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}"
         LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib \
         ${run_valgrind}./benchmark_train/benchmark_train \
         --modelFile=${ms_models_path}/${model_name}_train.ms \
         --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \
-        --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \
+        --expectedDataFile=${train_io_path}/${model_name}_output \
         --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \
         --epochs=${epoch_num} --numThreads=${threads}
         if [ $? = 0 ]; then
             run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
         else
             run_result='x86: '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}
+            fail=1
         fi
     done < ${models_mindspore_train_config}
+    return ${fail}
 }

 # Run on arm platform:

@@ -157,7 +160,7 @@ function Run_arm() {
     echo 'chmod 777 benchmark_train' >> ${adb_cmd_file}

     adb -s ${device_id} shell < ${adb_cmd_file}

+    fail=0
     # Run mindir converted train models:
     while read line; do
         model_name=${line}

@@ -167,7 +170,7 @@ function Run_arm() {

         # run benchmark_train test without clib data
         echo ${model_name}'_train' >> "${run_arm_log_file}"
-        adb -s ${device_id} push ${train_io_path}/${model_name}_input*.bin ${train_io_path}/${model_name}_outputs.bin /data/local/tmp/benchmark_train_test >> ${adb_push_log_file}
+        adb -s ${device_id} push ${train_io_path}/${model_name}_input*.bin ${train_io_path}/${model_name}_output*.bin /data/local/tmp/benchmark_train_test >> ${adb_push_log_file}
         echo 'cd /data/local/tmp/benchmark_train_test' > ${adb_cmd_run_file}
         echo 'chmod 777 benchmark_train' >> ${adb_cmd_run_file}
         if [ "$1" == arm64 ]; then

@@ -182,7 +185,7 @@ function Run_arm() {
                 --epochs=${epoch_num} \
                 --modelFile=${model_name}_train.ms \
                 --inDataFile=${tmp_dir}/${model_name}_input1.bin,${tmp_dir}/${model_name}_input2.bin \
-                --expectedDataFile=${tmp_dir}/${model_name}_outputs.bin \
+                --expectedDataFile=${tmp_dir}/${model_name}_output \
                 --exportFile=${tmp_dir}/${model_name}_train_exported.ms \
                 --numThreads=${threads}
ENDM

@@ -195,8 +198,11 @@ ENDM
             run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
         else
             run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file};
+            fail=1
         fi

     done < ${models_mindspore_train_config}
+    return ${fail}
 }

 # Print start msg before run testcase

@@ -59,6 +59,41 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
   }
 }

+void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
+  bool hasDepend = false;
+  std::vector<AnfNodePtr> inputs;
+  inputs.clear();
+
+  inputs.emplace_back(cnode->input(0));
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    AnfNodePtr inputNode = cnode->input(i);
+    if (!inputNode->isa<CNode>()) {
+      inputs.emplace_back(cnode->input(i));
+      continue;
+    }
+    auto dependNode = utils::cast<CNodePtr>(inputNode);
+    if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
+        IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
+      hasDepend = true;
+      bool maskOut = (dependNode->inputs().size() == 3);
+      for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
+        AnfNodePtr dependInputNode = dependNode->input(j);
+        if (dependInputNode->isa<CNode>()) {
+          inputs.emplace_back(dependInputNode);
+          if (maskOut) {
+            break;
+          }
+        }
+      }
+    } else {
+      inputs.emplace_back(cnode->input(i));
+    }
+  }
+  if (hasDepend) {
+    cnode->set_inputs(inputs);
+  }
+}
+
 int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                    const std::shared_ptr<PrimitiveC> &primitive,
                                    const std::unique_ptr<schema::CNodeT> &dst_node) {

@@ -251,8 +286,17 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
         break;
       }
     }

+    RemoveIfMakeTuple(cnode);
+#ifdef SUPPORT_TRAIN
+    RemoveIfDepend(cnode);
+#endif
+
     if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
+#ifdef SUPPORT_TRAIN
+        (primitive_c->Type() == schema::PrimitiveType_Depend) ||
+        (primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
+#endif
         (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
       continue;
     }

@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
-#define MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
+#ifndef MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_
+#define MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_

 #include <map>
 #include <string>

@@ -41,6 +41,7 @@ class AnfExporter {
   int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                      schema::CNodeT *fb_node);
   static void RemoveIfMakeTuple(const CNodePtr &cnode);
+  static void RemoveIfDepend(const CNodePtr &cnode);

 protected:
   int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);
@@ -97,4 +98,4 @@ class AnfExporter {
 // and clear.
 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
 }  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
+#endif  // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_

@@ -32,6 +32,42 @@ static const char *DELIM_COLON = ":";
 static const char *DELIM_COMMA = ",";
 static const char *DELIM_SLASH = "/";

+namespace {
+float *ReadFileBuf(const char *file, size_t *size) {
+  if (file == nullptr) {
+    MS_LOG(ERROR) << "file is nullptr";
+    return nullptr;
+  }
+  MS_ASSERT(size != nullptr);
+  std::string real_path = RealPath(file);
+  std::ifstream ifs(real_path);
+  if (!ifs.good()) {
+    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
+    return nullptr;
+  }
+
+  if (!ifs.is_open()) {
+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
+    return nullptr;
+  }
+
+  ifs.seekg(0, std::ios::end);
+  *size = ifs.tellg();
+  std::unique_ptr<float[]> buf((new (std::nothrow) float[*size / sizeof(float) + 1]));
+  if (buf == nullptr) {
+    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
+    ifs.close();
+    return nullptr;
+  }
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
+  ifs.close();
+
+  return buf.release();
+}
+}  // namespace
+
 int NetTrain::GenerateRandomData(size_t size, void *data) {
   MS_ASSERT(data != nullptr);
   char *casted_data = static_cast<char *>(data);

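ReadFileBuf loads an entire expected-output file as raw float32 data and hands ownership to the caller. It opens the stream in text mode while reading it as a binary blob; a hedged, stand-alone sketch of the same idea with an explicitly binary stream is below (the helper name and the vector return type are illustrative, not the benchmark's API):

#include <fstream>
#include <string>
#include <vector>

// Read a whole file of raw float32 values; returns an empty vector on failure.
std::vector<float> ReadFloatFile(const std::string &path) {
  std::ifstream ifs(path, std::ios::binary | std::ios::ate);  // open at end to learn the size
  if (!ifs.is_open()) {
    return {};
  }
  const std::streamsize bytes = ifs.tellg();
  std::vector<float> buf(static_cast<size_t>(bytes) / sizeof(float));
  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buf.data()),
           static_cast<std::streamsize>(buf.size() * sizeof(float)));
  return buf;
}
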
@@ -113,82 +149,34 @@ int NetTrain::ReadInputFile() {
   return RET_OK;
 }

-// calibData is FP32
-int NetTrain::ReadCalibData() {
-  const char *calib_data_path = flags_->data_file_.c_str();
-  // read calib data
-  std::ifstream in_file(calib_data_path);
-  if (!in_file.good()) {
-    std::cerr << "file: " << calib_data_path << " is not exist" << std::endl;
-    MS_LOG(ERROR) << "file: " << calib_data_path << " is not exist";
-    return RET_ERROR;
-  }
-
-  if (!in_file.is_open()) {
-    std::cerr << "file: " << calib_data_path << " open failed" << std::endl;
-    MS_LOG(ERROR) << "file: " << calib_data_path << " open failed";
-    in_file.close();
-    return RET_ERROR;
-  }
-
-  std::string line;
-
-  MS_LOG(INFO) << "Start reading calibData file";
-  std::string tensor_name;
-  while (!in_file.eof()) {
-    getline(in_file, line);
-    std::stringstream string_line1(line);
-    size_t dim = 0;
-    string_line1 >> tensor_name >> dim;
-    std::vector<size_t> dims;
-    size_t shape_size = 1;
-    for (size_t i = 0; i < dim; i++) {
-      size_t tmp_dim;
-      string_line1 >> tmp_dim;
-      dims.push_back(tmp_dim);
-      shape_size *= tmp_dim;
-    }
-
-    getline(in_file, line);
-    std::stringstream string_line2(line);
-    std::vector<float> tensor_data;
-    for (size_t i = 0; i < shape_size; i++) {
-      float tmp_data;
-      string_line2 >> tmp_data;
-      tensor_data.push_back(tmp_data);
-    }
-    auto *check_tensor = new CheckTensor(dims, tensor_data);
-    this->data_.insert(std::make_pair(tensor_name, check_tensor));
-  }
-  in_file.close();
-  MS_LOG(INFO) << "Finish reading calibData file";
-  return RET_OK;
-}
-
 int NetTrain::CompareOutput() {
   std::cout << "================ Comparing Output data ================" << std::endl;
   float total_bias = 0;
   int total_size = 0;
   bool has_error = false;

-  for (const auto &calib_tensor : data_) {
-    std::string node_or_tensor_name = calib_tensor.first;
-    auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name);
-    mindspore::tensor::MSTensor *tensor = nullptr;
-    if (tensors.empty() || tensors.size() != 1) {
-      MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
-                   << " or node has more than one output tensor, switch to GetOutputByTensorName";
-      tensor = session_->GetOutputByTensorName(node_or_tensor_name);
-      if (tensor == nullptr) {
-        MS_LOG(ERROR) << "Cannot find output tensor " << node_or_tensor_name << ", get model output failed";
-        return RET_ERROR;
-      }
-    } else {
-      tensor = tensors.front();
-    }
-    MS_ASSERT(tensor->MutableData() != nullptr);
+  auto tensors_list = session_->GetOutputs();
+  if (tensors_list.empty()) {
+    MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
+    return RET_ERROR;
+  }
+  mindspore::tensor::MSTensor *tensor = nullptr;
+  int i = 1;
+  for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
+    tensor = session_->GetOutputByTensorName(it->first);
     auto outputs = tensor->MutableData();
-    float bias = CompareData<float>(node_or_tensor_name, tensor->shape(), reinterpret_cast<float *>(outputs));
+    size_t size;
+    std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
+    auto *bin_buf = ReadFileBuf(output_file.c_str(), &size);
+    if (bin_buf == nullptr) {
+      MS_LOG(ERROR) << "ReadFile return nullptr";
+      return RET_ERROR;
+    }
+    if (size != tensor->Size()) {
+      MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
+                    << ", read size: " << size;
+      return RET_ERROR;
+    }
+    float bias = CompareData<float>(bin_buf, tensor->ElementsNum(), reinterpret_cast<float *>(outputs));
     if (bias >= 0) {
       total_bias += bias;
       total_size++;

@@ -196,6 +184,8 @@ int NetTrain::CompareOutput() {
       has_error = true;
       break;
     }
+    i++;
+    delete bin_buf;
   }

   if (!has_error) {

@@ -206,7 +196,8 @@ int NetTrain::CompareOutput() {
     mean_bias = 0;
   }

-  std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" << std::endl;
+  std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
+            << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
   std::cout << "=======================================================" << std::endl << std::endl;

   if (mean_bias > this->flags_->accuracy_threshold_) {

@@ -297,13 +288,6 @@ int NetTrain::MarkAccuracy() {
     return status;
   }

-  status = ReadCalibData();
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Read calib data error " << status;
-    std::cerr << "Read calib data error " << status << std::endl;
-    return status;
-  }
-
   status = CompareOutput();
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Compare output error " << status;

@@ -454,7 +438,7 @@ int NetTrain::RunNetTrain() {
       std::cout << "Run SaveToFile error";
       return RET_ERROR;
     }

+    // delete session_;
     status = RunExportedNet();
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Run Exported model error: " << status;

@@ -116,8 +116,6 @@ class MS_API NetTrain {

   int ReadInputFile();

-  int ReadCalibData();
-
   int CompareOutput();

   int InitCallbackParameter();

@@ -140,78 +138,49 @@ class MS_API NetTrain {

   // tensorData need to be converter first
   template <typename T>
-  float CompareData(const std::string &nodeName, std::vector<int> msShape, T *msTensorData) {
-    auto iter = this->data_.find(nodeName);
-    if (iter != this->data_.end()) {
-      std::vector<size_t> castedMSShape;
-      size_t shapeSize = 1;
-      for (int64_t dim : msShape) {
-        castedMSShape.push_back(size_t(dim));
-        shapeSize *= dim;
-      }
-
-      CheckTensor *calibTensor = iter->second;
-      if (calibTensor->shape != castedMSShape) {
-        std::ostringstream oss;
-        oss << "Shape of mslite output(";
-        for (auto dim : castedMSShape) {
-          oss << dim << ",";
-        }
-        oss << ") and shape source model output(";
-        for (auto dim : calibTensor->shape) {
-          oss << dim << ",";
-        }
-        oss << ") are different";
-        std::cerr << oss.str() << std::endl;
-        MS_LOG(ERROR) << oss.str().c_str();
-        return RET_ERROR;
-      }
-      size_t errorCount = 0;
-      float meanError = 0;
-      std::cout << "Data of node " << nodeName << " : ";
-      for (size_t j = 0; j < shapeSize; j++) {
-        if (j < 50) {
-          std::cout << static_cast<float>(msTensorData[j]) << " ";
-        }
-
-        if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
-          std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
-          MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
-          return RET_ERROR;
-        }
-
-        auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j));
-        auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j));
-        if (absoluteError > tolerance) {
-          if (fabs(calibTensor->data.at(j)) == 0) {
-            if (absoluteError > 1e-5) {
-              meanError += absoluteError;
-              errorCount++;
-            } else {
-              continue;
-            }
-          } else {
-            // just assume that atol = rtol
-            meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
-            errorCount++;
-          }
-        }
-      }
-      std::cout << std::endl;
-      if (meanError > 0.0f) {
-        meanError /= errorCount;
-      }
-
-      if (meanError <= 0.0000001) {
-        std::cout << "Mean bias of node/tensor " << nodeName << " : 0%" << std::endl;
-      } else {
-        std::cout << "Mean bias of node/tensor " << nodeName << " : " << meanError * 100 << "%" << std::endl;
-      }
-      return meanError;
-    } else {
-      MS_LOG(INFO) << "%s is not in Source Model output", nodeName.c_str();
-      return RET_ERROR;
-    }
-  }
+  float CompareData(const float *refOutput, int size, T *msTensorData) {
+    size_t errorCount = 0;
+    float meanError = 0;
+    std::cout << "Data of model output: ";
+    for (int j = 0; j < size; j++) {
+      if (j < 50) {
+        std::cout << static_cast<float>(msTensorData[j]) << " ";
+      }
+
+      if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
+        std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
+        MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
+        return RET_ERROR;
+      }
+
+      auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]);
+      auto absoluteError = std::fabs(msTensorData[j] - refOutput[j]);
+      if (absoluteError > tolerance) {
+        if (fabs(refOutput[j]) == 0) {
+          if (absoluteError > 1e-5) {
+            meanError += absoluteError;
+            errorCount++;
+          } else {
+            continue;
+          }
+        } else {
+          // just assume that atol = rtol
+          meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN);
+          errorCount++;
+        }
+      }
+    }
+    std::cout << std::endl;
+    if (meanError > 0.0f) {
+      meanError /= errorCount;
+    }
+
+    if (meanError <= 0.0000001) {
+      std::cout << "Mean bias of tensor: 0%" << std::endl;
+    } else {
+      std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl;
+    }
+    return meanError;
+  }

   int MarkPerformance();

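The refactored CompareData compares the model output against the raw reference buffer element by element: a value counts as an error only when it exceeds absoluteTolerance + relativeTolerance * |ref|, and the returned bias is the mean relative error over the offending elements (absolute error is used when the reference value is zero). A hypothetical stand-alone distillation of that rule follows; the atol / rtol defaults here are stand-ins, not the benchmark's actual constants:

#include <cfloat>
#include <cmath>

// Mean relative error of the elements that fall outside the mixed tolerance.
float MeanRelativeError(const float *ref, const float *out, int size,
                        float atol = 1e-9f, float rtol = 1e-5f) {
  float sum = 0.0f;
  int count = 0;
  for (int j = 0; j < size; ++j) {
    float abs_err = std::fabs(out[j] - ref[j]);
    if (abs_err <= atol + rtol * std::fabs(ref[j])) {
      continue;  // within tolerance, contributes nothing
    }
    if (ref[j] == 0.0f) {
      if (abs_err > 1e-5f) {                           // absolute error against a zero reference
        sum += abs_err;
        ++count;
      }
    } else {
      sum += abs_err / (std::fabs(ref[j]) + FLT_MIN);  // relative error otherwise
      ++count;
    }
  }
  return count > 0 ? sum / count : 0.0f;
}
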
@@ -144,8 +144,8 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt
 int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
                                    const converter::Flags *config) {
   auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
-  const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
   if (!config->trainModel) {
+    const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
     auto inne_context_ptr = std::make_shared<lite::InnerContext>();
     inne_context_ptr->Init();
     const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));