api support dual abi

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2021-02-19 17:09:40 +08:00
parent 674bca3fe4
commit 504f215800
18 changed files with 425 additions and 140 deletions

View File

@ -16,11 +16,11 @@
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H #define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <map>
#include <any>
#include <string> #include <string>
#include <memory> #include <memory>
#include <vector>
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore { namespace mindspore {
constexpr auto kDeviceTypeAscend310 = "Ascend310"; constexpr auto kDeviceTypeAscend310 = "Ascend310";
@ -28,38 +28,108 @@ constexpr auto kDeviceTypeAscend910 = "Ascend910";
constexpr auto kDeviceTypeGPU = "GPU"; constexpr auto kDeviceTypeGPU = "GPU";
struct MS_API Context { struct MS_API Context {
public:
Context();
virtual ~Context() = default; virtual ~Context() = default;
std::map<std::string, std::any> params; struct Data;
std::shared_ptr<Data> data;
}; };
struct MS_API GlobalContext : public Context { struct MS_API GlobalContext : public Context {
public:
static std::shared_ptr<Context> GetGlobalContext(); static std::shared_ptr<Context> GetGlobalContext();
static void SetGlobalDeviceTarget(const std::string &device_target); static inline void SetGlobalDeviceTarget(const std::string &device_target);
static std::string GetGlobalDeviceTarget(); static inline std::string GetGlobalDeviceTarget();
static void SetGlobalDeviceID(const uint32_t &device_id); static void SetGlobalDeviceID(const uint32_t &device_id);
static uint32_t GetGlobalDeviceID(); static uint32_t GetGlobalDeviceID();
private:
// api without std::string
static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
static std::vector<char> GetGlobalDeviceTargetChar();
}; };
struct MS_API ModelContext : public Context { struct MS_API ModelContext : public Context {
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path); public:
static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context); static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format); static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
static std::string GetInputFormat(const std::shared_ptr<Context> &context); static inline std::string GetInputFormat(const std::shared_ptr<Context> &context);
static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape); static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
static std::string GetInputShape(const std::shared_ptr<Context> &context); static inline std::string GetInputShape(const std::shared_ptr<Context> &context);
static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type); static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
static enum DataType GetOutputType(const std::shared_ptr<Context> &context); static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode); static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
static std::string GetPrecisionMode(const std::shared_ptr<Context> &context); static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode); static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context); const std::string &op_select_impl_mode);
static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
private:
// api without std::string
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context);
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format);
static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context);
static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape);
static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context);
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode);
static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context);
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::vector<char> &op_select_impl_mode);
static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
}; };
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
SetGlobalDeviceTarget(StringToChar(device_target));
}
std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
SetInsertOpConfigPath(context, StringToChar(cfg_path));
}
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
return CharToString(GetInsertOpConfigPathChar(context));
}
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
SetInputFormat(context, StringToChar(format));
}
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
return CharToString(GetInputFormatChar(context));
}
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
SetInputShape(context, StringToChar(shape));
}
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
return CharToString(GetInputShapeChar(context));
}
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
SetPrecisionMode(context, StringToChar(precision_mode));
}
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
return CharToString(GetPrecisionModeChar(context));
}
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode) {
SetOpSelectImplMode(context, StringToChar(op_select_impl_mode));
}
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
return CharToString(GetOpSelectImplModeChar(context));
}
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H #endif // MINDSPORE_INCLUDE_API_CONTEXT_H

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
#include <string>
#include <vector>
namespace mindspore {
inline std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }
inline std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_

View File

@ -17,7 +17,6 @@
#define MINDSPORE_INCLUDE_API_GRAPH_H #define MINDSPORE_INCLUDE_API_GRAPH_H
#include <cstddef> #include <cstddef>
#include <string>
#include <vector> #include <vector>
#include <map> #include <map>
#include <memory> #include <memory>

View File

@ -25,6 +25,7 @@
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/graph.h" #include "include/api/graph.h"
#include "include/api/cell.h" #include "include/api/cell.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore { namespace mindspore {
class ModelImpl; class ModelImpl;
@ -46,10 +47,16 @@ class MS_API Model {
std::vector<MSTensor> GetInputs(); std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs(); std::vector<MSTensor> GetOutputs();
static bool CheckModelSupport(const std::string &device_type, ModelType model_type); static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type);
private: private:
// api without std::string
static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type);
std::shared_ptr<ModelImpl> impl_; std::shared_ptr<ModelImpl> impl_;
}; };
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
return CheckModelSupport(StringToChar(device_type), model_type);
}
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H #endif // MINDSPORE_INCLUDE_API_MODEL_H

View File

@ -24,16 +24,24 @@
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/model.h" #include "include/api/model.h"
#include "include/api/graph.h" #include "include/api/graph.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore { namespace mindspore {
class MS_API Serialization { class MS_API Serialization {
public: public:
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type); static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
static Graph LoadModel(const std::string &file, ModelType model_type); inline static Graph LoadModel(const std::string &file, ModelType model_type);
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters); static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model); static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file); static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
private:
static Graph LoadModel(const std::vector<char> &file, ModelType model_type);
}; };
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
return LoadModel(StringToChar(file), model_type);
}
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H

View File

@ -16,9 +16,13 @@
#ifndef MINDSPORE_INCLUDE_API_STATUS_H #ifndef MINDSPORE_INCLUDE_API_STATUS_H
#define MINDSPORE_INCLUDE_API_STATUS_H #define MINDSPORE_INCLUDE_API_STATUS_H
#include <memory>
#include <string> #include <string>
#include <vector>
#include <ostream> #include <ostream>
#include <climits> #include <climits>
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
namespace mindspore { namespace mindspore {
enum CompCode : uint32_t { enum CompCode : uint32_t {
@ -100,46 +104,61 @@ enum StatusCode : uint32_t {
kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */ kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
}; };
class Status { class MS_API Status {
public: public:
Status() : status_code_(kSuccess), line_of_code_(-1) {} Status();
Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit) inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit)
: status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {} inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
~Status() = default; ~Status() = default;
enum StatusCode StatusCode() const { return status_code_; } enum StatusCode StatusCode() const;
const std::string &ToString() const { return status_msg_; } inline std::string ToString() const;
int GetLineOfCode() const { return line_of_code_; } int GetLineOfCode() const;
const std::string &GetErrDescription() const { return status_msg_; } inline std::string GetErrDescription() const;
const std::string &SetErrDescription(const std::string &err_description); inline std::string SetErrDescription(const std::string &err_description);
friend std::ostream &operator<<(std::ostream &os, const Status &s); friend std::ostream &operator<<(std::ostream &os, const Status &s);
bool operator==(const Status &other) const { return status_code_ == other.status_code_; } bool operator==(const Status &other) const;
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; } bool operator==(enum StatusCode other_code) const;
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; } bool operator!=(const Status &other) const;
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; } bool operator!=(enum StatusCode other_code) const;
explicit operator bool() const { return (status_code_ == kSuccess); } explicit operator bool() const;
explicit operator int() const { return static_cast<int>(status_code_); } explicit operator int() const;
static Status OK() { return Status(StatusCode::kSuccess); } static Status OK();
bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); } bool IsOk() const;
bool IsError() const { return !IsOk(); } bool IsError() const;
static std::string CodeAsString(enum StatusCode c); static inline std::string CodeAsString(enum StatusCode c);
private: private:
enum StatusCode status_code_; // api without std::string
std::string status_msg_; explicit Status(enum StatusCode status_code, const std::vector<char> &status_msg);
int line_of_code_; Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra);
std::string file_name_; std::vector<char> ToCString() const;
std::string err_description_; std::vector<char> GetErrDescriptionChar() const;
std::vector<char> SetErrDescription(const std::vector<char> &err_description);
static std::vector<char> CodeAsCString(enum StatusCode c);
struct Data;
std::shared_ptr<Data> data_;
}; };
Status::Status(enum StatusCode status_code, const std::string &status_msg)
: Status(status_code, StringToChar(status_msg)) {}
Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra)
: Status(code, line_of_code, file_name, StringToChar(extra)) {}
std::string Status::ToString() const { return CharToString(ToCString()); }
std::string Status::GetErrDescription() const { return CharToString(GetErrDescriptionChar()); }
std::string Status::SetErrDescription(const std::string &err_description) {
return CharToString(SetErrDescription(StringToChar(err_description)));
}
std::string Status::CodeAsString(enum StatusCode c) { return CharToString(CodeAsCString(c)); }
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_STATUS_H #endif // MINDSPORE_INCLUDE_API_STATUS_H

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "include/api/data_type.h" #include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"
#ifdef _WIN32 #ifdef _WIN32
#define MS_API __declspec(dllexport) #define MS_API __declspec(dllexport)
@ -42,18 +43,18 @@ class MS_API MSTensor {
public: public:
class Impl; class Impl;
static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept; const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept; const void *data, size_t data_len) noexcept;
MSTensor(); MSTensor();
explicit MSTensor(const std::shared_ptr<Impl> &impl); explicit MSTensor(const std::shared_ptr<Impl> &impl);
MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len); size_t data_len);
~MSTensor(); ~MSTensor();
const std::string &Name() const; inline std::string Name() const;
enum DataType DataType() const; enum DataType DataType() const;
const std::vector<int64_t> &Shape() const; const std::vector<int64_t> &Shape() const;
int64_t ElementNum() const; int64_t ElementNum() const;
@ -68,6 +69,15 @@ class MS_API MSTensor {
bool operator==(std::nullptr_t) const; bool operator==(std::nullptr_t) const;
private: private:
// api without std::string
static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
std::vector<char> CharName() const;
friend class ModelImpl; friend class ModelImpl;
explicit MSTensor(std::nullptr_t); explicit MSTensor(std::nullptr_t);
std::shared_ptr<Impl> impl_; std::shared_ptr<Impl> impl_;
@ -92,5 +102,21 @@ class MS_API Buffer {
class Impl; class Impl;
std::shared_ptr<Impl> impl_; std::shared_ptr<Impl> impl_;
}; };
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
return CreateTensor(StringToChar(name), type, shape, data, data_len);
}
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
}
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
std::string MSTensor::Name() const { return CharToString(CharName()); }
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_TYPES_H #endif // MINDSPORE_INCLUDE_API_TYPES_H

View File

@ -14,6 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
#include "include/api/context.h" #include "include/api/context.h"
#include <any>
#include <map>
#include <type_traits>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target"; constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode"; constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
namespace mindspore { namespace mindspore {
template <class T> struct Context::Data {
static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) { std::map<std::string, std::any> params;
auto iter = context->params.find(key); };
if (iter == context->params.end()) {
return T(); Context::Context() : data(std::make_shared<Data>()) {}
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
static U empty_result;
if (context == nullptr || context->data == nullptr) {
return empty_result;
}
auto iter = context->data->params.find(key);
if (iter == context->data->params.end()) {
return empty_result;
} }
const std::any &value = iter->second; const std::any &value = iter->second;
if (value.type() != typeid(T)) { if (value.type() != typeid(U)) {
return T(); return empty_result;
} }
return std::any_cast<T>(value); return std::any_cast<const U &>(value);
} }
std::shared_ptr<Context> GlobalContext::GetGlobalContext() { std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
@ -47,22 +60,31 @@ std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
return g_context; return g_context;
} }
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
auto global_context = GetGlobalContext(); auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context); MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceTarget] = device_target; if (global_context->data == nullptr) {
global_context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(global_context->data);
}
global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target);
} }
std::string GlobalContext::GetGlobalDeviceTarget() { std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
auto global_context = GetGlobalContext(); auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context); MS_EXCEPTION_IF_NULL(global_context);
return GetValue<std::string>(global_context, kGlobalContextDeviceTarget); const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
return StringToChar(ref);
} }
void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) { void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
auto global_context = GetGlobalContext(); auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context); MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceID] = device_id; if (global_context->data == nullptr) {
global_context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(global_context->data);
}
global_context->data->params[kGlobalContextDeviceID] = device_id;
} }
uint32_t GlobalContext::GetGlobalDeviceID() { uint32_t GlobalContext::GetGlobalDeviceID() {
@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() {
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID); return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
} }
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) { void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInsertOpCfgPath] = cfg_path; if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
} }
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) { std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInsertOpCfgPath); const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
return StringToChar(ref);
} }
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) { void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputFormat] = format; if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInputFormat] = CharToString(format);
} }
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) { std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputFormat); const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat);
return StringToChar(ref);
} }
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) { void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputShape] = shape; if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInputShape] = CharToString(shape);
} }
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) { std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputShape); const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape);
return StringToChar(ref);
} }
void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) { void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOutputType] = output_type; if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionOutputType] = output_type;
} }
enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) { enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &contex
return GetValue<enum DataType>(context, kModelOptionOutputType); return GetValue<enum DataType>(context, kModelOptionOutputType);
} }
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) { void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionPrecisionMode] = precision_mode; if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode);
} }
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) { std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionPrecisionMode); const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode);
return StringToChar(ref);
} }
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context, void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode) { const std::vector<char> &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode; if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
} }
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) { std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionOpSelectImplMode); const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
return StringToChar(ref);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -68,12 +68,13 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
Model::~Model() {} Model::~Model() {}
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { bool Model::CheckModelSupport(const std::vector<char> &device_type, ModelType model_type) {
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type)) { std::string device_type_str = CharToString(device_type);
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
return false; return false;
} }
auto first_iter = kSupportedModelMap.find(device_type); auto first_iter = kSupportedModelMap.find(device_type_str);
if (first_iter == kSupportedModelMap.end()) { if (first_iter == kSupportedModelMap.end()) {
return false; return false;
} }

View File

@ -84,17 +84,18 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type; MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
} }
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
std::string file_path = CharToString(file);
if (model_type == kMindIR) { if (model_type == kMindIR) {
FuncGraphPtr anf_graph = LoadMindIR(file); FuncGraphPtr anf_graph = LoadMindIR(file_path);
if (anf_graph == nullptr) { if (anf_graph == nullptr) {
MS_LOG(EXCEPTION) << "Load model failed."; MS_LOG(EXCEPTION) << "Load model failed.";
} }
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
} else if (model_type == kOM) { } else if (model_type == kOM) {
Buffer data = ReadFile(file); Buffer data = ReadFile(file_path);
if (data.Data() == nullptr) { if (data.Data() == nullptr) {
MS_LOG(EXCEPTION) << "Read file " << file << " failed."; MS_LOG(EXCEPTION) << "Read file " << file_path << " failed.";
} }
return Graph(std::make_shared<Graph::GraphData>(data, kOM)); return Graph(std::make_shared<Graph::GraphData>(data, kOM));
} }

View File

@ -134,10 +134,11 @@ class TensorReferenceImpl : public MSTensor::Impl {
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
}; };
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept { const void *data, size_t data_len) noexcept {
std::string name_str = CharToString(name);
try { try {
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len); std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len);
return MSTensor(impl); return MSTensor(impl);
} catch (const std::bad_alloc &) { } catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed."; MS_LOG(ERROR) << "Malloc memory failed.";
@ -148,10 +149,11 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con
} }
} }
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept { const void *data, size_t data_len) noexcept {
std::string name_str = CharToString(name);
try { try {
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len); std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len);
return MSTensor(impl); return MSTensor(impl);
} catch (const std::bad_alloc &) { } catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed."; MS_LOG(ERROR) << "Malloc memory failed.";
@ -165,9 +167,9 @@ MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type,
MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {} MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {} MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); } MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); }
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data, MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
size_t data_len) const void *data, size_t data_len)
: impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {} : impl_(std::make_shared<TensorDefaultImpl>(CharToString(name), type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default; MSTensor::~MSTensor() = default;
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; } bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
@ -179,9 +181,9 @@ MSTensor MSTensor::Clone() const {
return ret; return ret;
} }
const std::string &MSTensor::Name() const { std::vector<char> MSTensor::CharName() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->Name(); return StringToChar(impl_->Name());
} }
enum DataType MSTensor::DataType() const { enum DataType MSTensor::DataType() const {

View File

@ -296,8 +296,7 @@ else()
endif() endif()
endif() endif()
add_dependencies(_c_dataengine mindspore_shared_lib) target_link_libraries(_c_dataengine PRIVATE mindspore_core mindspore_shared_lib)
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
if(USE_GLOG) if(USE_GLOG)
target_link_libraries(_c_dataengine PRIVATE mindspore::glog) target_link_libraries(_c_dataengine PRIVATE mindspore::glog)

View File

@ -686,6 +686,9 @@ class Tensor {
/// pointer to the end of the physical data /// pointer to the end of the physical data
unsigned char *data_end_ = nullptr; unsigned char *data_end_ = nullptr;
/// shape for interpretation of YUV image
std::vector<uint32_t> yuv_shape_;
private: private:
friend class DETensor; friend class DETensor;

View File

@ -24,16 +24,42 @@
#include <sstream> #include <sstream>
namespace mindspore { namespace mindspore {
Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) { struct Status::Data {
status_code_ = code; enum StatusCode status_code = kSuccess;
line_of_code_ = line_of_code; std::string status_msg;
file_name_ = std::string(file_name); int line_of_code = -1;
err_description_ = extra; std::string file_name;
std::string err_description;
};
Status::Status() : data_(std::make_shared<Data>()) {}
Status::Status(enum StatusCode status_code, const std::vector<char> &status_msg) : data_(std::make_shared<Data>()) {
if (data_ == nullptr) {
return;
}
data_->status_msg = CharToString(status_msg);
data_->status_code = status_code;
}
Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra)
: data_(std::make_shared<Data>()) {
if (data_ == nullptr) {
return;
}
data_->status_code = code;
data_->line_of_code = line_of_code;
if (file_name != nullptr) {
data_->file_name = file_name;
}
data_->err_description = CharToString(extra);
std::ostringstream ss; std::ostringstream ss;
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". "; ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". ";
if (!extra.empty()) { if (!data_->err_description.empty()) {
ss << extra; ss << data_->err_description;
} }
ss << "\n"; ss << "\n";
#endif #endif
@ -42,10 +68,38 @@ Status::Status(enum StatusCode code, int line_of_code, const char *file_name, co
if (file_name != nullptr) { if (file_name != nullptr) {
ss << "File : " << file_name << "\n"; ss << "File : " << file_name << "\n";
} }
status_msg_ = ss.str(); data_->status_msg = ss.str();
} }
std::string Status::CodeAsString(enum StatusCode c) { enum StatusCode Status::StatusCode() const {
if (data_ == nullptr) {
return kSuccess;
}
return data_->status_code;
}
std::vector<char> Status::ToCString() const {
if (data_ == nullptr) {
return std::vector<char>();
}
return StringToChar(data_->status_msg);
}
int Status::GetLineOfCode() const {
if (data_ == nullptr) {
return -1;
}
return data_->line_of_code;
}
std::vector<char> Status::GetErrDescriptionChar() const {
if (data_ == nullptr) {
return std::vector<char>();
}
return StringToChar(data_->status_msg);
}
std::vector<char> Status::CodeAsCString(enum StatusCode c) {
static std::map<enum StatusCode, std::string> info_map = {{kSuccess, "No error occurs."}, static std::map<enum StatusCode, std::string> info_map = {{kSuccess, "No error occurs."},
// Core // Core
{kCoreFailed, "Common error code."}, {kCoreFailed, "Common error code."},
@ -98,7 +152,7 @@ std::string Status::CodeAsString(enum StatusCode c) {
{kLiteInferInvalid, "Invalid infer shape before runtime."}, {kLiteInferInvalid, "Invalid infer shape before runtime."},
{kLiteInputParamInvalid, "Invalid input param by user."}}; {kLiteInputParamInvalid, "Invalid input param by user."}};
auto iter = info_map.find(c); auto iter = info_map.find(c);
return iter == info_map.end() ? "Unknown error" : iter->second; return StringToChar(iter == info_map.end() ? "Unknown error" : iter->second);
} }
std::ostream &operator<<(std::ostream &os, const Status &s) { std::ostream &operator<<(std::ostream &os, const Status &s) {
@ -106,22 +160,48 @@ std::ostream &operator<<(std::ostream &os, const Status &s) {
return os; return os;
} }
const std::string &Status::SetErrDescription(const std::string &err_description) { std::vector<char> Status::SetErrDescription(const std::vector<char> &err_description) {
err_description_ = err_description; if (data_ == nullptr) {
return std::vector<char>();
}
data_->err_description = CharToString(err_description);
std::ostringstream ss; std::ostringstream ss;
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(status_code_) << ". "; ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". ";
if (!err_description_.empty()) { if (!data_->err_description.empty()) {
ss << err_description_; ss << data_->err_description;
} }
ss << "\n"; ss << "\n";
#endif #endif
if (line_of_code_ > 0 && !file_name_.empty()) { if (data_->line_of_code > 0 && !data_->file_name.empty()) {
ss << "Line of code : " << line_of_code_ << "\n"; ss << "Line of code : " << data_->line_of_code << "\n";
ss << "File : " << file_name_ << "\n"; ss << "File : " << data_->file_name << "\n";
} }
status_msg_ = ss.str(); data_->status_msg = ss.str();
return status_msg_; return StringToChar(data_->status_msg);
} }
bool Status::operator==(const Status &other) const {
if (data_ == nullptr && other.data_ == nullptr) {
return true;
}
if (data_ == nullptr || other.data_ == nullptr) {
return false;
}
return data_->status_code == other.data_->status_code;
}
bool Status::operator==(enum StatusCode other_code) const { return StatusCode() == other_code; }
bool Status::operator!=(const Status &other) const { return !operator==(other); }
bool Status::operator!=(enum StatusCode other_code) const { return !operator==(other_code); }
Status::operator bool() const { return (StatusCode() == kSuccess); }
Status::operator int() const { return static_cast<int>(StatusCode()); }
Status Status::OK() { return StatusCode::kSuccess; }
bool Status::IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
bool Status::IsError() const { return !IsOk(); }
} // namespace mindspore } // namespace mindspore

View File

@ -67,7 +67,7 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
Model::~Model() {} Model::~Model() {}
bool Model::CheckModelSupport(const std::string &device_type, ModelType) { bool Model::CheckModelSupport(const std::vector<char> &, ModelType) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(ERROR) << "Unsupported feature.";
return false; return false;
} }

View File

@ -47,7 +47,7 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
return graph; return graph;
} }
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Unsupported Feature.";
return Graph(nullptr); return Graph(nullptr);
} }

View File

@ -60,16 +60,16 @@ class Buffer::Impl {
MSTensor::MSTensor() : impl_(std::make_shared<Impl>()) {} MSTensor::MSTensor() : impl_(std::make_shared<Impl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {} MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) {} MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) {}
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data, MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
size_t data_len) const void *data, size_t data_len)
: impl_(std::make_shared<Impl>(name, type, shape, data, data_len)) {} : impl_(std::make_shared<Impl>(CharToString(name), type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default; MSTensor::~MSTensor() = default;
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; } bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept { const void *data, size_t data_len) noexcept {
auto impl = std::make_shared<Impl>(name, type, shape, data, data_len); auto impl = std::make_shared<Impl>(CharToString(name), type, shape, data, data_len);
if (impl == nullptr) { if (impl == nullptr) {
MS_LOG(ERROR) << "Allocate tensor impl failed."; MS_LOG(ERROR) << "Allocate tensor impl failed.";
return MSTensor(nullptr); return MSTensor(nullptr);
@ -77,7 +77,7 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con
return MSTensor(impl); return MSTensor(impl);
} }
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept { const void *data, size_t data_len) noexcept {
auto tensor = CreateTensor(name, type, shape, data, data_len); auto tensor = CreateTensor(name, type, shape, data, data_len);
if (tensor == nullptr) { if (tensor == nullptr) {
@ -98,13 +98,12 @@ MSTensor MSTensor::Clone() const {
return ret; return ret;
} }
const std::string &MSTensor::Name() const { std::vector<char> MSTensor::CharName() const {
static std::string empty = "";
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor inpmlement."; MS_LOG(ERROR) << "Invalid tensor inpmlement.";
return empty; return std::vector<char>();
} }
return impl_->Name(); return StringToChar(impl_->Name());
} }
int64_t MSTensor::ElementNum() const { int64_t MSTensor::ElementNum() const {

View File

@ -60,12 +60,6 @@ TEST_F(TestCxxApiContext, test_context_ascend310_context_nullptr_FAILED) {
EXPECT_ANY_THROW(ModelContext::GetInsertOpConfigPath(nullptr)); EXPECT_ANY_THROW(ModelContext::GetInsertOpConfigPath(nullptr));
} }
TEST_F(TestCxxApiContext, test_context_ascend310_context_wrong_type_SUCCESS) {
auto ctx = std::make_shared<ModelContext>();
ctx->params["mindspore.option.op_select_impl_mode"] = 5;
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");
}
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) { TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
auto ctx = std::make_shared<ModelContext>(); auto ctx = std::make_shared<ModelContext>();
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), ""); ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");