!22070 [lite]enhance dumpgraph function and adjust examples on the basis of unified api
Merge pull request !22070 from 徐安越/master_core
This commit is contained in:
commit
0d6559092d
|
@ -14,206 +14,97 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cstring>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/model.h"
|
||||
#include "include/context.h"
|
||||
#include "include/lite_session.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/model.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr int kNumPrintOfOutData = 20;
|
||||
std::string RealPath(const char *path) {
|
||||
const size_t max = 4096;
|
||||
if (path == nullptr) {
|
||||
std::cerr << "path is nullptr" << std::endl;
|
||||
return "";
|
||||
Status FillInputData(const std::vector<mindspore::MSTensor> &inputs) {
|
||||
for (auto tensor : inputs) {
|
||||
auto input_data = tensor.MutableData();
|
||||
if (input_data == nullptr) {
|
||||
std::cerr << "MallocData for inTensor failed.\n";
|
||||
return kLiteError;
|
||||
}
|
||||
std::vector<float> temp(tensor.ElementNum(), 1.0f);
|
||||
memcpy(input_data, temp.data(), tensor.DataSize());
|
||||
}
|
||||
if ((strlen(path)) >= max) {
|
||||
std::cerr << "path is too long" << std::endl;
|
||||
return "";
|
||||
}
|
||||
auto resolved_path = std::make_unique<char[]>(max);
|
||||
if (resolved_path == nullptr) {
|
||||
std::cerr << "new resolved_path failed" << std::endl;
|
||||
return "";
|
||||
}
|
||||
#ifdef _WIN32
|
||||
char *real_path = _fullpath(resolved_path.get(), path, 1024);
|
||||
#else
|
||||
char *real_path = realpath(path, resolved_path.get());
|
||||
#endif
|
||||
if (real_path == nullptr || strlen(real_path) == 0) {
|
||||
std::cerr << "file path is not valid : " << path << std::endl;
|
||||
return "";
|
||||
}
|
||||
std::string res = resolved_path.get();
|
||||
return res;
|
||||
}
|
||||
|
||||
char *ReadFile(const char *file, size_t *size) {
|
||||
if (file == nullptr) {
|
||||
std::cerr << "file is nullptr." << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cerr << "file: " << file << " is not exist." << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cerr << "file: " << file << " open failed." << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
*size = ifs.tellg();
|
||||
std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
|
||||
if (buf == nullptr) {
|
||||
std::cerr << "malloc buf failed, file: " << file << std::endl;
|
||||
ifs.close();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf.get(), *size);
|
||||
ifs.close();
|
||||
|
||||
return buf.release();
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename Distribution>
|
||||
void GenerateRandomData(int size, void *data, Distribution distribution) {
|
||||
std::mt19937 random_engine;
|
||||
int elements_num = size / sizeof(T);
|
||||
(void)std::generate_n(static_cast<T *>(data), elements_num,
|
||||
[&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
|
||||
}
|
||||
|
||||
int GenerateInputDataWithRandom(std::vector<mindspore::tensor::MSTensor *> inputs) {
|
||||
for (auto tensor : inputs) {
|
||||
auto input_data = tensor->MutableData();
|
||||
if (input_data == nullptr) {
|
||||
std::cerr << "MallocData for inTensor failed." << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
GenerateRandomData<float>(tensor->Size(), input_data, std::uniform_real_distribution<float>(1.0f, 1.0f));
|
||||
Status CompileAndRun(int argc, const char **argv) {
|
||||
if (argc < 2) {
|
||||
std::cerr << "Model file must be provided.\n";
|
||||
return kLiteError;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
// generate context.
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
if (context == nullptr) {
|
||||
std::cerr << "New context failed while running.\n";
|
||||
return kLiteError;
|
||||
}
|
||||
auto &device_list = context->MutableDeviceInfo();
|
||||
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
|
||||
device_info->SetProvider("Tutorial");
|
||||
device_info->SetProviderDevice("Tutorial");
|
||||
device_list.push_back(device_info);
|
||||
|
||||
int Run(mindspore::session::LiteSession *session) {
|
||||
auto inputs = session->GetInputs();
|
||||
// build model.
|
||||
std::string model_file = std::string(argv[1]);
|
||||
mindspore::Model model;
|
||||
auto ret = model.Build(model_file, kMindIR, context);
|
||||
if (ret != kSuccess) {
|
||||
std::cerr << "build model failed.\n";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
// Generate random data as input data.
|
||||
auto ret = GenerateInputDataWithRandom(inputs);
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Generate Random Input Data failed." << std::endl;
|
||||
// fill input data.
|
||||
auto inputs = model.GetInputs();
|
||||
ret = FillInputData(inputs);
|
||||
if (ret != kSuccess) {
|
||||
std::cerr << "Generate Random Input Data failed.\n";
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Run Inference.
|
||||
ret = session->RunGraph();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Inference error " << ret << std::endl;
|
||||
// run model.
|
||||
std::vector<MSTensor> outputs;
|
||||
ret = model.Predict(inputs, &outputs);
|
||||
if (ret != kSuccess) {
|
||||
std::cerr << "run model failed.\n";
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Get Output Tensor Data.
|
||||
auto out_tensors = session->GetOutputs();
|
||||
for (auto tensor : out_tensors) {
|
||||
std::cout << "tensor name is:" << tensor.first << " tensor size is:" << tensor.second->Size()
|
||||
<< " tensor elements num is:" << tensor.second->ElementsNum() << std::endl;
|
||||
auto out_data = reinterpret_cast<float *>(tensor.second->MutableData());
|
||||
// display output result.
|
||||
for (auto tensor : outputs) {
|
||||
std::cout << "tensor name is:" << tensor.Name() << " tensor size is:" << tensor.DataSize()
|
||||
<< " tensor elements num is:" << tensor.ElementNum() << std::endl;
|
||||
auto out_data = std::static_pointer_cast<const float>(tensor.Data());
|
||||
std::cout << "output data is:";
|
||||
for (int i = 0; i < tensor.second->ElementsNum() && i <= kNumPrintOfOutData; i++) {
|
||||
std::cout << out_data[i] << " ";
|
||||
for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
|
||||
std::cout << out_data.get()[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
mindspore::session::LiteSession *Compile(mindspore::lite::Model *model) {
|
||||
// Create and init context.
|
||||
auto context = std::make_shared<Context>();
|
||||
if (context == nullptr) {
|
||||
std::cerr << "New context failed while." << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
context->device_list_[0].provider_ = "Tutorial";
|
||||
context->device_list_[0].provider_device_ = "Tutorial";
|
||||
// Create the session.
|
||||
auto *session = mindspore::session::LiteSession::CreateSession(context.get());
|
||||
if (session == nullptr) {
|
||||
std::cerr << "CreateSession failed while running." << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Compile graph.
|
||||
auto ret = session->CompileGraph(model);
|
||||
if (ret != RET_OK) {
|
||||
delete session;
|
||||
std::cerr << "Compile failed while running." << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return session;
|
||||
}
|
||||
|
||||
int CompileAndRun(int argc, const char **argv) {
|
||||
if (argc < 2) {
|
||||
std::cerr << "Model file must be provided.\n";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// Read model file.
|
||||
auto model_path = RealPath(argv[1]);
|
||||
if (model_path.empty()) {
|
||||
std::cerr << "model path " << argv[1] << " is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t size = 0;
|
||||
char *model_buf = ReadFile(model_path.c_str(), &size);
|
||||
if (model_buf == nullptr) {
|
||||
std::cerr << "Read model file failed." << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
// Load the .ms model.
|
||||
auto model = Model::Import(model_buf, size);
|
||||
delete[](model_buf);
|
||||
if (model == nullptr) {
|
||||
std::cerr << "Import model file failed." << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
// Compile MindSpore Lite model.
|
||||
auto session = Compile(model);
|
||||
if (session == nullptr) {
|
||||
delete model;
|
||||
std::cerr << "Create session failed." << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
// Run inference.
|
||||
auto ret = Run(session);
|
||||
if (ret != RET_OK) {
|
||||
delete model;
|
||||
delete session;
|
||||
std::cerr << "MindSpore Lite run failed." << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
// Delete model buffer.
|
||||
delete model;
|
||||
// Delete session buffer.
|
||||
delete session;
|
||||
return RET_OK;
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
int main(int argc, const char **argv) { return mindspore::lite::CompileAndRun(argc, argv); }
|
||||
int main(int argc, const char **argv) {
|
||||
auto ret = mindspore::lite::CompileAndRun(argc, argv);
|
||||
if (ret != mindspore::kSuccess) {
|
||||
std::cerr << "run failed.\n";
|
||||
return -1;
|
||||
}
|
||||
std::cout << "run success.\n";
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -191,7 +191,8 @@ STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
(void)Manage(mirror_graph, true);
|
||||
if (!RunOptimizerPass(mirror_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
|
||||
if (!RunOptimizerPass(mirror_graph,
|
||||
{"ToNHWCFormat", "InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
|
||||
MS_LOG(ERROR) << "Run transpose opt pass failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -177,7 +177,8 @@ void UnifyFormatToNHWC::SetSensitiveOps() {
|
|||
sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end());
|
||||
}
|
||||
|
||||
bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) {
|
||||
bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
|
||||
const ShapeVector &shape) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
if (fmk_type_ == converter::kFmkTypeTf || fmk_type_ == converter::kFmkTypeTflite) {
|
||||
return false;
|
||||
|
|
|
@ -36,7 +36,8 @@ class UnifyFormatToNHWC : public opt::ToFormatBase {
|
|||
STATUS ConvertOnnxResizeForVariableShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
|
||||
void SetSensitiveOps() override;
|
||||
bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) override;
|
||||
bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
|
||||
const ShapeVector &shape) override;
|
||||
bool DecideWhetherInferShapeForNewNode() override;
|
||||
STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
|
||||
schema::Format *dst_format) override;
|
||||
|
|
|
@ -178,6 +178,31 @@ STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const C
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool ToFormatBase::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
|
||||
const ShapeVector &shape) {
|
||||
MS_ASSERT(func_graph != nullptr && input != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(anager != nullptr);
|
||||
auto node_users = manager->node_users()[input];
|
||||
for (auto &node_user : node_users) {
|
||||
auto post_node = node_user.first;
|
||||
if (!utils::isa<CNode>(post_node)) {
|
||||
continue;
|
||||
}
|
||||
auto post_cnode = post_node->cast<CNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
|
||||
MS_ASSERT(prim != nullptr);
|
||||
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
if (node_format == format_) {
|
||||
MS_LOG(DEBUG) << "this graph input don't need to change.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto graph_input = func_graph->get_inputs();
|
||||
|
@ -191,7 +216,7 @@ STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (shape.size() != kDimNumber || !DecideWhetherHandleGraphInput(func_graph, shape)) {
|
||||
if (shape.size() != kDimNumber || !DecideWhetherHandleGraphInput(func_graph, input_param, shape)) {
|
||||
continue;
|
||||
}
|
||||
ShapeVector transfer_shape;
|
||||
|
@ -298,8 +323,6 @@ bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph)
|
|||
STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited) {
|
||||
MS_ASSERT(graph != nullptr && has_visited != nullptr);
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
schema::Format src_format = schema::Format_NUM_OF_FORMAT;
|
||||
schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
|
@ -335,6 +358,8 @@ STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<A
|
|||
continue;
|
||||
}
|
||||
has_visited->insert(node);
|
||||
schema::Format src_format = schema::Format_NUM_OF_FORMAT;
|
||||
schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
|
||||
if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "weight's src format and dst format get failed.";
|
||||
return lite::RET_ERROR;
|
||||
|
|
|
@ -52,7 +52,8 @@ class ToFormatBase : public Pass {
|
|||
protected:
|
||||
virtual STATUS GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0;
|
||||
virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); }
|
||||
virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) { return true; }
|
||||
virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
|
||||
const ShapeVector &shape);
|
||||
virtual bool DecideWhetherInferShapeForNewNode() { return true; }
|
||||
virtual STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
|
||||
schema::Format *dst_format) = 0;
|
||||
|
|
|
@ -45,6 +45,20 @@ STATUS ToNCHWFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTyp
|
|||
STATUS ToNCHWFormat::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
|
||||
schema::Format *dst_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_ASSERT(prim != nullptr);
|
||||
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
if (node_format == mindspore::NCHW) {
|
||||
MS_LOG(DEBUG) << "node's format has been nchw, no need to transfer, " << cnode->fullname_with_scope();
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (node_format != mindspore::NHWC) {
|
||||
MS_LOG(ERROR) << "node's format is invalid, which must be nhwc or nchw, now is " << node_format
|
||||
<< ", node name is " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
*src_format = schema::Format_KHWC;
|
||||
*dst_format = schema::Format_KCHW;
|
||||
return lite::RET_OK;
|
||||
|
|
|
@ -45,6 +45,19 @@ STATUS ToNHWCFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTyp
|
|||
STATUS ToNHWCFormat::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
|
||||
schema::Format *dst_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
if (node_format == mindspore::NHWC) {
|
||||
MS_LOG(DEBUG) << "node's format has been nhwc, no need to transfer, " << cnode->fullname_with_scope();
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (node_format != mindspore::NCHW) {
|
||||
MS_LOG(ERROR) << "node's format is invalid, which must be nhwc or nchw, now is " << node_format
|
||||
<< ", node name is " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
*src_format = schema::Format_KCHW;
|
||||
*dst_format = schema::Format_KHWC;
|
||||
return lite::RET_OK;
|
||||
|
|
|
@ -136,7 +136,7 @@ bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
|
|||
} else {
|
||||
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
all_op_can_infer = false;
|
||||
|
|
Loading…
Reference in New Issue