[MSLITE] mix fp16/fp32

ling 2021-08-18 12:02:51 +08:00
parent fd06532b59
commit 35ec7ab07d
16 changed files with 362 additions and 49 deletions
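
This commit adds per-operator fp16/fp32 mixing: a config file, loaded through the new Model::LoadConfig API before Build, pins named ops to float32 or float16, while unnamed ops keep the scheduler's default precision. The config format, as exercised by the tests at the end of this commit (quotes around the value are optional):

  [execution_plan]
  op1=data_type:float16
  op2="data_type:float32"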

View File

@@ -18,17 +18,16 @@
namespace mindspore {
namespace lite {
std::map<std::string, std::string> GetSectionInfoFromConfigFile(const std::string &file,
const std::string &section_name) {
std::map<std::string, std::string> section_info;
int GetSectionInfoFromConfigFile(const std::string &file, const std::string &section_name,
std::map<std::string, std::string> *section_info) {
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr";
return section_info;
return RET_ERROR;
}
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
if (resolved_path == nullptr) {
MS_LOG(ERROR) << "new resolved_path failed";
return section_info;
return RET_ERROR;
}
#ifdef _WIN32
@@ -38,16 +37,16 @@ std::map<std::string, std::string> GetSectionInfoFromConfigFile(const std::strin
#endif
if (real_path == nullptr || strlen(real_path) == 0) {
MS_LOG(ERROR) << "file path is not valid : " << file;
return section_info;
return RET_ERROR;
}
std::ifstream ifs(resolved_path.get());
if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << real_path << " is not exist";
return section_info;
return RET_ERROR;
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "file: " << real_path << "open failed";
return section_info;
return RET_ERROR;
}
std::string line;
@@ -81,12 +80,41 @@ std::map<std::string, std::string> GetSectionInfoFromConfigFile(const std::strin
auto value = line.substr(index + 1);
lite::Trim(&key);
lite::Trim(&value);
section_info.insert(std::make_pair(key, value));
section_info->insert(std::make_pair(key, value));
}
}
ifs.close();
return section_info;
return RET_OK;
}
void ParserExecutionPlan(const std::map<std::string, std::string> *config_infos,
std::map<std::string, TypeId> *data_type_plan) {
for (const auto &info : *config_infos) {
std::string op_name = info.first;
std::string value = info.second;
if (value.size() > 1 && value.front() == '"' && value.back() == '"') {
value = value.substr(1, value.length() - 2);
}
auto index = value.find(':');
if (index == std::string::npos) {
continue;
}
auto data_type_key = value.substr(0, index);
auto data_type_value = value.substr(index + 1);
if (data_type_key != "data_type") {
continue;
}
TypeId type_id = kTypeUnknown;
if (data_type_value == "float32") {
type_id = kNumberTypeFloat32;
} else if (data_type_value == "float16") {
type_id = kNumberTypeFloat16;
} else {
continue;
}
data_type_plan->insert(std::make_pair(op_name, type_id));
}
}
} // namespace lite
} // namespace mindspore
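
Taken together, the two helpers above turn a config section into the scheduler's plan: GetSectionInfoFromConfigFile now signals success via an int return code and fills a caller-owned map, and ParserExecutionPlan reduces each entry to a TypeId, silently skipping keys other than data_type and values other than float32/float16. A minimal sketch of chaining them (the file name is hypothetical):

  std::map<std::string, std::string> section_info;
  if (lite::GetSectionInfoFromConfigFile("plan.cfg", CONFIG_FILE_EXECUTION_PLAN, &section_info) != lite::RET_OK) {
    return;  // bad path or unreadable file
  }
  std::map<std::string, TypeId> plan;
  lite::ParserExecutionPlan(&section_info, &plan);  // e.g. {"op1": kNumberTypeFloat16}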

View File

@@ -30,13 +30,19 @@
#include <utility>
#include "src/common/utils.h"
#include "src/common/log_adapter.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
constexpr int MAX_CONFIG_FILE_LENGTH = 1024;
#define CONFIG_FILE_EXECUTION_PLAN "execution_plan"
int GetSectionInfoFromConfigFile(const std::string &file, const std::string &section_name,
std::map<std::string, std::string> *section_info);
void ParserExecutionPlan(const std::map<std::string, std::string> *config_infos,
std::map<std::string, TypeId> *data_type_plan);
std::map<std::string, std::string> GetSectionInfoFromConfigFile(const std::string &file,
const std::string &section_name);
} // namespace lite
} // namespace mindspore

View File

@ -15,6 +15,7 @@
*/
#include "include/api/model.h"
#include <mutex>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/api/callback/callback.h"
@@ -25,14 +26,19 @@
#include "src/common/log_adapter.h"
namespace mindspore {
std::mutex g_impl_init_lock;
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
}
}
Status ret = impl_->Build(model_data, data_size, model_type, model_context);
if (ret != kSuccess) {
return ret;
@@ -42,11 +48,15 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
}
}
Status ret = impl_->Build(model_path, model_type, model_context);
if (ret != kSuccess) {
return ret;
@@ -57,16 +67,15 @@ Status Model::Build(const std::string &model_path, ModelType model_type, const s
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
const std::shared_ptr<TrainCfg> &train_cfg) {
std::stringstream err_msg;
if (impl_ != nullptr) {
MS_LOG(DEBUG) << "Model has been already built.";
return kSuccess;
}
impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) {
err_msg << "Model implement is null.";
MS_LOG(ERROR) << err_msg.str();
return Status(kLiteNullptr, err_msg.str());
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
}
}
if (graph.GetGraph() == nullptr) {
err_msg << "Invalid null graph.";
MS_LOG(ERROR) << err_msg.str();
@@ -161,6 +170,27 @@ std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_
return impl_->GetOutputsByNodeName(CharToString(node_name));
}
Status Model::LoadConfig(const std::string &config_path) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
if (impl_ != nullptr) {
MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
return kLiteFileError;
}
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
}
auto ret = impl_->LoadConfig(config_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "impl_ LoadConfig failed,";
return kLiteFileError;
}
return kSuccess;
}
Status Model::SetTrainMode(bool train) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model is null.";

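Note the ordering constraint baked into LoadConfig above: it refuses to run once impl_ exists, so the config must be loaded before the model is built. A hedged usage sketch mirroring the st test at the end of this commit (flat_model, size, and context are assumed to come from the caller):

  Model model;
  if (model.LoadConfig("MixDataTypeTestConfig") != kSuccess) {
    return;  // missing or malformed config file
  }
  // Build afterwards; the parsed execution plan flows into the session and scheduler.
  auto status = model.Build(flat_model, size, kFlatBuffer, context);
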
View File

@@ -29,6 +29,7 @@
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "src/common/file_utils.h"
#include "src/common/config_file.h"
namespace mindspore {
using mindspore::lite::RET_ERROR;
@@ -183,6 +184,23 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac
bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); }
Status ModelImpl::LoadConfig(const std::string &config_path) {
std::map<std::string, std::string> config_info;
int ret = lite::GetSectionInfoFromConfigFile(config_path, CONFIG_FILE_EXECUTION_PLAN, &config_info);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetSectionInfoFromConfigFile failed.";
return kLiteFileError;
}
if (config_info.empty()) {
MS_LOG(WARNING) << "No valid info in config file.";
return kSuccess;
}
lite::ParserExecutionPlan(&config_info, &execution_plan_);
return kSuccess;
}
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (outputs == nullptr) {
@@ -462,6 +480,8 @@ session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context)
return nullptr;
}
session->InitExecutionConfig(&execution_plan_);
auto ret = session->Init(context);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init session failed";

View File

@@ -69,6 +69,7 @@ class ModelImpl {
session::LiteSession *CreateLiteSession(lite::InnerContext *context);
Status LoadConfig(const std::string &config_path);
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
MSTensor GetInputByTensorName(const std::string &name);
@@ -106,6 +107,7 @@
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
std::map<std::string, TypeId> execution_plan_;
};
} // namespace mindspore

View File

@@ -32,12 +32,7 @@ const constexpr int kMaxLiteContextDeviceNums = 2;
const constexpr int kMaxInnerContextDeviceNums = 3;
} // namespace
InnerContext::InnerContext(const Context *context) {
this->allocator = context->allocator;
this->thread_num_ = context->thread_num_;
this->enable_parallel_ = context->enable_parallel_;
this->affinity_core_list_ = context->affinity_core_list_;
SetContextDevice(context);
void InnerContext::InitDeviceFp16() {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
CpuInfo cpu_info;
device_and_pkg_support_fp16_ = cpu_info.ArmIsSupportFp16();
@@ -46,6 +41,15 @@ InnerContext::InnerContext(const Context *context) {
#endif
}
InnerContext::InnerContext(const Context *context) {
this->allocator = context->allocator;
this->thread_num_ = context->thread_num_;
this->enable_parallel_ = context->enable_parallel_;
this->affinity_core_list_ = context->affinity_core_list_;
SetContextDevice(context);
InitDeviceFp16();
}
void InnerContext::SetContextDevice(const Context *context) {
MS_ASSERT(context->device_list_.size() <= kMaxLiteContextDeviceNums);

View File

@@ -28,7 +28,7 @@
namespace mindspore::lite {
struct InnerContext : public Context {
public:
InnerContext() = default;
InnerContext() { InitDeviceFp16(); }
explicit InnerContext(const Context *context);
@@ -77,6 +77,8 @@ struct InnerContext : public Context {
void SetContextDevice(const Context *context);
void InitDeviceFp16();
bool device_and_pkg_support_fp16_ = false;
ActorThreadPool *thread_pool_{nullptr};

View File

@@ -509,7 +509,8 @@ int LiteSession::CompileGraph(Model *model) {
InitGraphInputTensors(model);
InitGraphOutputTensors(model);
// scheduler kernels
Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, delegate_);
Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, execution_plan_,
delegate_);
scheduler.SetupSchedulerCb(std::move(sched_cb_));
ret = scheduler.Schedule(&kernels_);
if (ret != RET_OK) {

View File

@@ -21,6 +21,7 @@
#include <vector>
#include <string>
#include <unordered_map>
#include <map>
#include <atomic>
#include "src/lite_kernel.h"
#include "include/ms_tensor.h"
@@ -77,6 +78,8 @@ class LiteSession : public session::LiteSession {
int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims) override;
void InitExecutionConfig(std::map<std::string, TypeId> *config) { execution_plan_ = config; }
void set_model(Model *model) { this->model_ = model; }
const std::vector<kernel::LiteKernel *> &get_kernels() const { return this->kernels_; }
@@ -158,6 +161,7 @@ class LiteSession : public session::LiteSession {
#endif
std::unique_ptr<SchedulerCb> sched_cb_;
std::shared_ptr<Delegate> delegate_ = nullptr;
std::map<std::string, TypeId> *execution_plan_ = nullptr;
};
} // namespace lite
} // namespace mindspore

View File

@@ -718,6 +718,17 @@ inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tens
}
} // namespace
void Scheduler::ResetByExecutionPlan(const std::string &node_name, TypeId *data_type) {
if (execution_plan_ == nullptr) {
return;
}
auto iter = execution_plan_->find(node_name);
if (iter != execution_plan_->end()) {
*data_type = iter->second;
}
}
int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel) {
@@ -1192,6 +1203,9 @@ kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src
std::vector<Tensor *> outputs;
MS_ASSERT(src_node != nullptr);
FindNodeInoutTensors(*src_node, &inputs, &outputs);
ResetByExecutionPlan(src_node->name_, &prefer_data_type);
auto *kernel = this->FindBackendKernel(inputs, outputs, src_node, prefer_data_type);
op_parameters_[src_node->output_indices_.at(0)] = nullptr;
if (kernel == nullptr) {
@@ -1276,7 +1290,7 @@ int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kern
kernel = ScheduleNodeToKernel(node, prefer_data_type);
}
if (kernel == nullptr || ret != RET_OK) {
MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << node->name_
MS_LOG(ERROR) << "schedule node return nullptr, name: " << node->name_
<< ", type: " << GetPrimitiveTypeName(primitive, schema_version_);
return RET_ERROR;
}

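ResetByExecutionPlan is the single hook point: right before FindBackendKernel selects a kernel, the preferred data type of any node named in the plan is overwritten, and every other node is left alone. With the test config below (op1=float16, op2=float32, op3 unset, op4=float32) on an fp16-capable CPU, the four kernels come out fp16/fp32/fp16/fp32, op3 inheriting the fp16 default.
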
View File

@@ -24,6 +24,7 @@
#include <deque>
#include <unordered_map>
#include <set>
#include <string>
#include "src/sub_graph_kernel.h"
#include "src/inner_context.h"
#include "include/model.h"
@@ -41,7 +42,7 @@ class Scheduler {
Scheduler(const InnerContext *ctx, const mindspore::Context *ms_ctx, Model *src_model,
std::vector<Tensor *> *src_tensors, const std::vector<Tensor *> &input_tensors,
const std::vector<Tensor *> &output_tensors, bool is_train_session,
std::shared_ptr<Delegate> delegate = nullptr)
std::map<std::string, TypeId> *executions, std::shared_ptr<Delegate> delegate = nullptr)
: context_(ctx),
ms_context_(ms_ctx),
src_model_(src_model),
@@ -49,7 +50,8 @@
inputs_(input_tensors),
outputs_(output_tensors),
is_train_session_(is_train_session),
delegate_(delegate) {}
delegate_(delegate),
execution_plan_(executions) {}
~Scheduler() = default;
int Schedule(std::vector<kernel::LiteKernel *> *dst_kernels);
void SetupSchedulerCb(std::unique_ptr<SchedulerCb> cb) { sched_cb_ = std::move(cb); }
@@ -72,6 +74,8 @@
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel);
int CheckCpuValid(std::vector<kernel::LiteKernel *> *dst_kernels);
void ResetByExecutionPlan(const std::string &node_name, TypeId *data_type);
#ifdef GPU_OPENCL
int FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
@@ -148,6 +152,7 @@
std::set<lite::Model::Node *> partial_cnode_inferred_{};
#endif
int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
std::map<std::string, TypeId> *execution_plan_ = nullptr;
};
} // namespace mindspore::lite

View File

@@ -329,6 +329,7 @@ set(TEST_SRC
${TEST_CASE_KERNEL_SRC}
${TEST_DIR}/main.cc
${TEST_DIR}/common/common_test.cc
${TEST_DIR}/st/mix_data_type_test.cc
${TEST_DIR}/ut/src/infer_test.cc
${TEST_DIR}/ut/src/utils_test.cc
${TEST_DIR}/ut/src/dynamic_library_loader_test.cc

View File

@@ -144,5 +144,6 @@ MultipleDeviceTest.NewApi7
MultipleDeviceTest.NewApi8
MindrtRuntimeTest.Runtime
MindrtRuntimeTest.RuntimeFp16
MixDataTypeTest.mix1
SchedulerTest.TestScheduleInt32OpToFp16Subgraph

View File

@@ -91,3 +91,6 @@ echo 'user set output tensors st test'
echo 'runtime pass'
./lite-test --gtest_filter="RuntimePass.*"
echo 'Runtime config file test'
./lite-test --gtest_filter="MixDataTypeTest.Config1"

View File

@@ -18,6 +18,14 @@
#include "common/common_test.h"
#include "include/errorcode.h"
#include "src/common/config_file.h"
#include "schema/inner/model_generated.h"
#include "schema/inner/ops_generated.h"
#include "schema/ops_generated.h"
#include "schema/model_generated.h"
#include "src/lite_kernel.h"
#include "src/lite_session.h"
#include "include/api/model.h"
#include "src/cxx_api/model/model_impl.h"
namespace mindspore {
class MixDataTypeTest : public mindspore::CommonTest {
@@ -26,24 +34,208 @@ class MixDataTypeTest : public mindspore::CommonTest {
};
TEST_F(MixDataTypeTest, Config1) {
auto ret = system("echo [execution_plan] > MixDataTypeTestConfig");
auto ret = system("echo [onther_plan1] > MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo op1=data_type:fp32 >> MixDataTypeTestConfig");
ret = system("echo op1=data_type:float16 >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo \"op2=\\\"data_type:fp16\\\"\" >> MixDataTypeTestConfig");
ret = system("echo [execution_plan] >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo op1=data_type:float32 >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo \"op2=\\\"data_type:float16\\\"\" >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo [onther_plan2] >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo op1=data_type:float16 >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
std::string filename = "MixDataTypeTestConfig";
std::string sectionname = "execution_plan";
auto execution_plan = lite::GetSectionInfoFromConfigFile(filename, sectionname);
std::map<std::string, std::string> config_info;
ret = lite::GetSectionInfoFromConfigFile(filename, sectionname, &config_info);
ASSERT_EQ(ret, 0);
ASSERT_EQ(config_info.size(), 2);
auto info0 = config_info.at("op1");
ASSERT_EQ(info0, "data_type:float32");
auto info1 = config_info.at("op2");
ASSERT_EQ(info1, "\"data_type:float16\"");
std::map<std::string, TypeId> execution_plan;
lite::ParserExecutionPlan(&config_info, &execution_plan);
ASSERT_EQ(execution_plan.size(), 2);
auto info0 = execution_plan.at("op1");
ASSERT_EQ(info0, "data_type:fp32");
auto exe_info0 = execution_plan.at("op1");
ASSERT_EQ(exe_info0, kNumberTypeFloat32);
auto inf01 = execution_plan.at("op2");
ASSERT_EQ(inf01, "\"data_type:fp16\"");
auto exe_info1 = execution_plan.at("op2");
ASSERT_EQ(exe_info1, kNumberTypeFloat16);
}
void ConstructConfig() {
auto ret = system("echo [execution_plan] > MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo op1=data_type:float16 >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
ret = system("echo op2=data_type:float32 >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
/* op3 is deliberately left out of the plan and falls back to fp16 */
ret = system("echo op4=data_type:float32 >> MixDataTypeTestConfig");
ASSERT_EQ(ret, 0);
}
void ConstructModel(schema::MetaGraphT *meta_graph) {
meta_graph->name = "mix_data_type_graph";
meta_graph->version = mindspore::lite::Version();
auto cos = std::make_unique<mindspore::schema::CNodeT>();
cos->inputIndex = {0};
cos->outputIndex = {1};
cos->primitive = std::make_unique<mindspore::schema::PrimitiveT>();
cos->primitive->value.type = mindspore::schema::PrimitiveType_Cos;
auto cos_primitive = new mindspore::schema::CosT;
cos->primitive->value.value = cos_primitive;
cos->name = "op1";
auto exp = std::make_unique<mindspore::schema::CNodeT>();
exp->inputIndex = {1};
exp->outputIndex = {2};
exp->primitive = std::make_unique<mindspore::schema::PrimitiveT>();
exp->primitive->value.type = mindspore::schema::PrimitiveType_ExpFusion;
auto exp_primitive = new mindspore::schema::ExpFusionT;
exp->primitive->value.value = exp_primitive;
exp->name = "op2";
auto sin = std::make_unique<mindspore::schema::CNodeT>();
sin->inputIndex = {2};
sin->outputIndex = {3};
sin->primitive = std::make_unique<mindspore::schema::PrimitiveT>();
sin->primitive->value.type = mindspore::schema::PrimitiveType_Sin;
auto sin_primitive = new mindspore::schema::SinT;
sin->primitive->value.value = sin_primitive;
sin->name = "op3";
auto cos2 = std::make_unique<mindspore::schema::CNodeT>();
cos2->inputIndex = {3};
cos2->outputIndex = {4};
cos2->primitive = std::make_unique<mindspore::schema::PrimitiveT>();
cos2->primitive->value.type = mindspore::schema::PrimitiveType_Cos;
auto cos2_primitive = new mindspore::schema::CosT;
cos2->primitive->value.value = cos2_primitive;
cos2->name = "op4";
/* tensors */
auto tensor0 = std::make_unique<mindspore::schema::TensorT>();
tensor0->nodeType = mindspore::lite::NodeType_ValueNode;
tensor0->format = mindspore::schema::Format_NHWC;
tensor0->dataType = mindspore::TypeId::kNumberTypeFloat32;
tensor0->dims = {1, 2, 2, 1};
tensor0->offset = -1;
tensor0->name = "tensor0";
auto tensor1 = std::make_unique<mindspore::schema::TensorT>();
tensor1->nodeType = mindspore::lite::NodeType_ValueNode;
tensor1->format = mindspore::schema::Format_NHWC;
tensor1->dataType = mindspore::TypeId::kNumberTypeFloat32;
tensor1->dims = {1, 2, 2, 1};
tensor1->offset = -1;
tensor1->name = "tensor1";
auto tensor2 = std::make_unique<mindspore::schema::TensorT>();
tensor2->nodeType = mindspore::lite::NodeType_ValueNode;
tensor2->format = mindspore::schema::Format_NHWC;
tensor2->dataType = mindspore::TypeId::kNumberTypeFloat32;
tensor2->dims = {1, 2, 2, 1};
tensor2->offset = -1;
tensor2->name = "tensor2";
auto tensor3 = std::make_unique<mindspore::schema::TensorT>();
tensor3->nodeType = mindspore::lite::NodeType_ValueNode;
tensor3->format = mindspore::schema::Format_NHWC;
tensor3->dataType = mindspore::TypeId::kNumberTypeFloat32;
tensor3->dims = {1, 2, 2, 1};
tensor3->offset = -1;
tensor3->name = "tensor3";
auto tensor4 = std::make_unique<mindspore::schema::TensorT>();
tensor4->nodeType = mindspore::lite::NodeType_ValueNode;
tensor4->format = mindspore::schema::Format_NHWC;
tensor4->dataType = mindspore::TypeId::kNumberTypeFloat32;
tensor4->dims = {1, 2, 2, 1};
tensor4->offset = -1;
tensor4->name = "tensor4";
meta_graph->nodes.emplace_back(std::move(cos));
meta_graph->nodes.emplace_back(std::move(exp));
meta_graph->nodes.emplace_back(std::move(sin));
meta_graph->nodes.emplace_back(std::move(cos2));
meta_graph->allTensors.emplace_back(std::move(tensor0));
meta_graph->allTensors.emplace_back(std::move(tensor1));
meta_graph->allTensors.emplace_back(std::move(tensor2));
meta_graph->allTensors.emplace_back(std::move(tensor3));
meta_graph->allTensors.emplace_back(std::move(tensor4));
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {4};
}
TEST_F(MixDataTypeTest, mix1) {
ConstructConfig();
size_t size;
auto meta_graph = std::make_shared<schema::MetaGraphT>();
ConstructModel(meta_graph.get());
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
size = builder.GetSize();
auto flat_model = reinterpret_cast<char *>(builder.GetBufferPointer());
auto context = std::make_shared<mindspore::Context>();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(true);
context->MutableDeviceInfo().push_back(device_info);
auto impl = std::make_shared<mindspore::ModelImpl>();
auto status = impl->LoadConfig("MixDataTypeTestConfig");
ASSERT_EQ(status, kSuccess);
status = impl->Build(flat_model, size, kFlatBuffer, context);
ASSERT_EQ(status, kSuccess);
/* check */
auto kernels = reinterpret_cast<const lite::LiteSession *>(impl->GetSession())->get_kernels();
ASSERT_EQ(4, kernels.size());
ASSERT_EQ(kNumberTypeFloat16, kernels.at(0)->desc().data_type);
ASSERT_EQ(kNumberTypeFloat32, kernels.at(1)->desc().data_type);
ASSERT_EQ(kNumberTypeFloat16, kernels.at(2)->desc().data_type);
ASSERT_EQ(kNumberTypeFloat32, kernels.at(3)->desc().data_type);
/* set input data */
std::vector<mindspore::MSTensor> inputs = impl->GetInputs();
auto in = inputs[0];
std::vector<float> in_float = {1.0, 2.0, 3.0, 4.0};
memcpy(in.MutableData(), in_float.data(), in.DataSize());
std::vector<mindspore::MSTensor> outputs = impl->GetOutputs();
impl->Predict(inputs, &outputs, nullptr, nullptr);
/* checkout output */
auto out = outputs[0];
void *out_data = out.MutableData();
float *fp32_data = reinterpret_cast<float *>(out_data);
ASSERT_LE(fabs(fp32_data[0] - (0.549187)), 0.01);
ASSERT_LE(fabs(fp32_data[1] - (0.818051)), 0.01);
ASSERT_LE(fabs(fp32_data[2] - (0.934805)), 0.01);
ASSERT_LE(fabs(fp32_data[3] - (0.879054)), 0.01);
}
} // namespace mindspore

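For reference, the expected outputs in mix1 are simply cos(sin(exp(cos(x)))) computed in fp32: for x = 1.0 the chain gives cos(1) ≈ 0.5403, exp(0.5403) ≈ 1.7165, sin(1.7165) ≈ 0.9894, cos(0.9894) ≈ 0.5492, matching the 0.549187 reference above, so the mixed-precision pipeline only has to stay within 0.01 of the fp32 result.
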
View File

@@ -428,7 +428,7 @@ TEST_F(MultipleDeviceTest, NewApi5) {
model_impl->Predict(inputs, &outputs, nullptr, nullptr);
/* checkout output */
auto out = outputs[0]; /* output data control by users */
auto out = outputs[0];
void *out_data = out.MutableData();
float *fp32_data = reinterpret_cast<float *>(out_data);
@@ -452,8 +452,8 @@ TEST_F(MultipleDeviceTest, NewApi6) {
auto context = std::make_shared<mindspore::Context>();
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
// context->MutableDeviceInfo().push_back(std::make_shared<mindspore::KirinNPUDeviceInfo>());
// context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::KirinNPUDeviceInfo>());
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
auto model_impl = std::make_shared<mindspore::ModelImpl>();
auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
@@ -473,7 +473,7 @@ TEST_F(MultipleDeviceTest, NewApi6) {
model_impl->Predict(inputs, &outputs, nullptr, nullptr);
/* checkout output */
auto out = outputs[0]; /* output data control by users */
auto out = outputs[0];
void *out_data = out.MutableData();
float *fp32_data = reinterpret_cast<float *>(out_data);