Make code examples more robust to CI

This commit is contained in:
Emir Haleva 2021-07-20 14:34:13 +03:00
parent 31a4c3116e
commit ec1be107ad
3 changed files with 9 additions and 8 deletions

View File

@ -38,6 +38,7 @@
using mindspore::dataset::Dataset;
using mindspore::dataset::Mnist;
using mindspore::dataset::SequentialSampler;
using mindspore::dataset::TensorOperation;
using mindspore::dataset::transforms::TypeCast;
using mindspore::dataset::vision::Normalize;
@ -185,7 +186,7 @@ float NetRunner::CalculateAccuracy(int max_tests) {
}
int NetRunner::InitDB() {
train_ds_ = Mnist(data_dir_ + "/train", "all");
train_ds_ = Mnist(data_dir_ + "/train", "all", std::make_shared<SequentialSampler>(0, 0));
TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
Resize resize({h_, w_});

View File

@ -43,6 +43,7 @@ using mindspore::TrainCallBack;
using mindspore::TrainCallBackData;
using mindspore::dataset::Dataset;
using mindspore::dataset::Mnist;
using mindspore::dataset::SequentialSampler;
using mindspore::dataset::TensorOperation;
using mindspore::dataset::transforms::TypeCast;
using mindspore::dataset::vision::Normalize;
@ -161,16 +162,13 @@ void NetRunner::InitAndFigureInputs() {
MS_ASSERT(status != mindspore::kSuccess);
}
// if (verbose_) {
// loop_->SetKernelCallBack(nullptr, after_callback);
//}
acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
model_->InitMetrics({acc_metrics_.get()});
auto inputs = model_->GetInputs();
MS_ASSERT(inputs.size() >= 1);
auto nhwc_input_dims = inputs.at(0).Shape();
// MS_ASSERT(nhwc_input_dims.size() == kNCHWDims);
batch_size_ = nhwc_input_dims.at(0);
h_ = nhwc_input_dims.at(1);
w_ = nhwc_input_dims.at(kNCHWCDim);
@ -186,8 +184,6 @@ float NetRunner::CalculateAccuracy(int max_tests) {
test_ds_ = test_ds_->Map({&typecast}, {"label"});
test_ds_ = test_ds_->Batch(batch_size_, true);
// Rescaler rescale(kScalePoint);
model_->Evaluate(test_ds_, {});
std::cout << "Accuracy is " << acc_metrics_->Eval() << std::endl;
@ -195,7 +191,7 @@ float NetRunner::CalculateAccuracy(int max_tests) {
}
int NetRunner::InitDB() {
train_ds_ = Mnist(data_dir_ + "/train", "all");
train_ds_ = Mnist(data_dir_ + "/train", "all", std::make_shared<SequentialSampler>(0, 0));
TypeCast typecast_f(mindspore::DataType::kNumberTypeFloat32);
Resize resize({h_, w_});

View File

@ -379,8 +379,10 @@ function Run_CodeExamples() {
accurate=$(tail -20 ${run_code_examples_log_file} | awk 'NF==3 && /Accuracy is/ { sum += $3} END { print (sum > 1.9) }')
if [ $accurate -eq 1 ]; then
echo "Unified API Trained and reached accuracy" >> ${run_code_examples_log_file}
echo 'code_examples: unified_api pass' >> ${run_benchmark_train_result_file}
else
echo "Unified API demo failure" >> ${run_code_examples_log_file}
echo 'code_examples: unified_api failed' >> ${run_benchmark_train_result_file}
fail=1
fi
rm -rf package*/dataset
@ -393,8 +395,10 @@ function Run_CodeExamples() {
accurate=$(tail -10 ${run_code_examples_log_file} | awk 'NF==3 && /Accuracy is/ { sum += $3} END { print (sum > 1.9) }')
if [ $accurate -eq 1 ]; then
echo "Lenet Trained and reached accuracy" >> ${run_code_examples_log_file}
echo 'code_examples: train_lenet pass' >> ${run_benchmark_train_result_file}
else
echo "Train Lenet demo failure" >> ${run_code_examples_log_file}
echo 'code_examples: train_lenet failed' >> ${run_benchmark_train_result_file}
fail=1
fi
rm -rf package*/dataset