api: support dual ABI

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
zhoufeng 2021-02-19 17:09:40 +08:00
parent 674bca3fe4
commit 504f215800
18 changed files with 425 additions and 140 deletions
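The pattern applied across these headers is the same everywhere: each public method that takes or returns std::string becomes an inline wrapper defined in the header, and the exported symbol only uses std::vector<char>. The wrapper is compiled into the caller, so no std::string crosses the shared-library boundary and the library works with either libstdc++ string ABI. Below is a minimal, self-contained sketch of that pattern; the class and method names are illustrative only, not MindSpore API.

// Minimal sketch of the dual-ABI pattern used throughout this commit.
// Widget/SetName/GetName are illustrative names, not MindSpore API.
#include <iostream>
#include <string>
#include <vector>

inline std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }
inline std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }

class Widget {
 public:
  // Public std::string API: declared inline and defined in the header,
  // so it is compiled with the caller's own string ABI.
  inline void SetName(const std::string &name);
  inline std::string GetName() const;

 private:
  // Only std::vector<char> crosses the (would-be) library boundary;
  // its layout is not affected by the libstdc++ dual ABI.
  void SetNameChar(const std::vector<char> &name) { name_ = name; }
  std::vector<char> GetNameChar() const { return name_; }
  std::vector<char> name_;
};

// Header-side wrappers convert at the call site.
void Widget::SetName(const std::string &name) { SetNameChar(StringToChar(name)); }
std::string Widget::GetName() const { return CharToString(GetNameChar()); }

int main() {
  Widget w;
  w.SetName("Ascend310");
  std::cout << w.GetName() << std::endl;  // prints: Ascend310
  return 0;
}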

View File

@@ -16,11 +16,11 @@
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <map>
#include <any>
#include <string>
#include <memory>
#include <vector>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
constexpr auto kDeviceTypeAscend310 = "Ascend310";
@@ -28,38 +28,108 @@ constexpr auto kDeviceTypeAscend910 = "Ascend910";
constexpr auto kDeviceTypeGPU = "GPU";
struct MS_API Context {
public:
Context();
virtual ~Context() = default;
std::map<std::string, std::any> params;
struct Data;
std::shared_ptr<Data> data;
};
struct MS_API GlobalContext : public Context {
public:
static std::shared_ptr<Context> GetGlobalContext();
static void SetGlobalDeviceTarget(const std::string &device_target);
static std::string GetGlobalDeviceTarget();
static inline void SetGlobalDeviceTarget(const std::string &device_target);
static inline std::string GetGlobalDeviceTarget();
static void SetGlobalDeviceID(const uint32_t &device_id);
static uint32_t GetGlobalDeviceID();
private:
// api without std::string
static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
static std::vector<char> GetGlobalDeviceTargetChar();
};
struct MS_API ModelContext : public Context {
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
public:
static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
static std::string GetInputFormat(const std::shared_ptr<Context> &context);
static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
static inline std::string GetInputFormat(const std::shared_ptr<Context> &context);
static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
static std::string GetInputShape(const std::shared_ptr<Context> &context);
static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
static inline std::string GetInputShape(const std::shared_ptr<Context> &context);
static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode);
static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
private:
// api without std::string
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context);
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format);
static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context);
static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape);
static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context);
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode);
static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context);
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::vector<char> &op_select_impl_mode);
static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
};
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
SetGlobalDeviceTarget(StringToChar(device_target));
}
std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
SetInsertOpConfigPath(context, StringToChar(cfg_path));
}
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
return CharToString(GetInsertOpConfigPathChar(context));
}
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
SetInputFormat(context, StringToChar(format));
}
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
return CharToString(GetInputFormatChar(context));
}
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
SetInputShape(context, StringToChar(shape));
}
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
return CharToString(GetInputShapeChar(context));
}
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
SetPrecisionMode(context, StringToChar(precision_mode));
}
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
return CharToString(GetPrecisionModeChar(context));
}
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode) {
SetOpSelectImplMode(context, StringToChar(op_select_impl_mode));
}
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
return CharToString(GetOpSelectImplModeChar(context));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
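Caller code is unchanged by this: it keeps using the std::string overloads, which now forward char vectors internally. A hedged usage sketch (the option values below are illustrative):

// Hypothetical caller-side configuration; only the inline std::string
// overloads are used, the std::vector<char> overloads stay private.
#include <memory>
#include "include/api/context.h"

void ConfigureAscend310() {
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  mindspore::GlobalContext::SetGlobalDeviceID(0);

  auto model_context = std::make_shared<mindspore::ModelContext>();
  mindspore::ModelContext::SetInputFormat(model_context, "NCHW");                   // illustrative value
  mindspore::ModelContext::SetPrecisionMode(model_context, "allow_fp32_to_fp16");   // illustrative value
}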

View File

@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
#include <string>
#include <vector>
namespace mindspore {
inline std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }
inline std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
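The two helpers are plain byte copies, so the conversion is lossless in both directions, including strings with embedded '\0' bytes. A small round-trip check, as a sketch:

// Round trip through the helpers is byte-for-byte (illustrative check).
#include <cassert>
#include <string>
#include <vector>
#include "include/api/dual_abi_helper.h"

int main() {
  const std::string s("a\0b", 3);                    // embedded NUL survives
  std::vector<char> c = mindspore::StringToChar(s);
  assert(c.size() == 3);
  assert(mindspore::CharToString(c) == s);           // lossless round trip
  return 0;
}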

View File

@@ -17,7 +17,6 @@
#define MINDSPORE_INCLUDE_API_GRAPH_H
#include <cstddef>
#include <string>
#include <vector>
#include <map>
#include <memory>

View File

@@ -25,6 +25,7 @@
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/cell.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
class ModelImpl;
@@ -46,10 +47,16 @@ class MS_API Model {
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type);
private:
// api without std::string
static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type);
std::shared_ptr<ModelImpl> impl_;
};
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
return CheckModelSupport(StringToChar(device_type), model_type);
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H
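From the caller's side CheckModelSupport keeps its std::string signature; the conversion happens in the inline wrapper. For example (device constant from context.h, model type assumed to be kMindIR):

// Hypothetical capability check before building a model.
#include "include/api/context.h"
#include "include/api/model.h"

bool Ascend310SupportsMindIR() {
  return mindspore::Model::CheckModelSupport(mindspore::kDeviceTypeAscend310, mindspore::kMindIR);
}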

View File

@@ -24,16 +24,24 @@
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/graph.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
class MS_API Serialization {
public:
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
static Graph LoadModel(const std::string &file, ModelType model_type);
inline static Graph LoadModel(const std::string &file, ModelType model_type);
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
private:
static Graph LoadModel(const std::vector<char> &file, ModelType model_type);
};
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
return LoadModel(StringToChar(file), model_type);
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
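Likewise, loading a model by path stays a one-liner for the caller; only the private overload sees a std::vector<char>. A hedged example:

// Hypothetical load of a MindIR model through the inline std::string overload.
#include <string>
#include "include/api/serialization.h"

mindspore::Graph LoadMindIRModel(const std::string &path) {
  return mindspore::Serialization::LoadModel(path, mindspore::kMindIR);
}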

View File

@@ -16,9 +16,13 @@
#ifndef MINDSPORE_INCLUDE_API_STATUS_H
#define MINDSPORE_INCLUDE_API_STATUS_H
#include <memory>
#include <string>
#include <vector>
#include <ostream>
#include <climits>
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
namespace mindspore {
enum CompCode : uint32_t {
@@ -100,46 +104,61 @@ enum StatusCode : uint32_t {
kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
};
class Status {
class MS_API Status {
public:
Status() : status_code_(kSuccess), line_of_code_(-1) {}
Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit)
: status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {}
Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
Status();
inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit)
inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
~Status() = default;
enum StatusCode StatusCode() const { return status_code_; }
const std::string &ToString() const { return status_msg_; }
enum StatusCode StatusCode() const;
inline std::string ToString() const;
int GetLineOfCode() const { return line_of_code_; }
const std::string &GetErrDescription() const { return status_msg_; }
const std::string &SetErrDescription(const std::string &err_description);
int GetLineOfCode() const;
inline std::string GetErrDescription() const;
inline std::string SetErrDescription(const std::string &err_description);
friend std::ostream &operator<<(std::ostream &os, const Status &s);
bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
bool operator==(const Status &other) const;
bool operator==(enum StatusCode other_code) const;
bool operator!=(const Status &other) const;
bool operator!=(enum StatusCode other_code) const;
explicit operator bool() const { return (status_code_ == kSuccess); }
explicit operator int() const { return static_cast<int>(status_code_); }
explicit operator bool() const;
explicit operator int() const;
static Status OK() { return Status(StatusCode::kSuccess); }
static Status OK();
bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
bool IsOk() const;
bool IsError() const { return !IsOk(); }
bool IsError() const;
static std::string CodeAsString(enum StatusCode c);
static inline std::string CodeAsString(enum StatusCode c);
private:
enum StatusCode status_code_;
std::string status_msg_;
int line_of_code_;
std::string file_name_;
std::string err_description_;
// api without std::string
explicit Status(enum StatusCode status_code, const std::vector<char> &status_msg);
Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra);
std::vector<char> ToCString() const;
std::vector<char> GetErrDescriptionChar() const;
std::vector<char> SetErrDescription(const std::vector<char> &err_description);
static std::vector<char> CodeAsCString(enum StatusCode c);
struct Data;
std::shared_ptr<Data> data_;
};
Status::Status(enum StatusCode status_code, const std::string &status_msg)
: Status(status_code, StringToChar(status_msg)) {}
Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra)
: Status(code, line_of_code, file_name, StringToChar(extra)) {}
std::string Status::ToString() const { return CharToString(ToCString()); }
std::string Status::GetErrDescription() const { return CharToString(GetErrDescriptionChar()); }
std::string Status::SetErrDescription(const std::string &err_description) {
return CharToString(SetErrDescription(StringToChar(err_description)));
}
std::string Status::CodeAsString(enum StatusCode c) { return CharToString(CodeAsCString(c)); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_STATUS_H
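With the pimpl Data struct, Status no longer holds std::string members across the ABI boundary; the inline accessors rebuild strings on the caller's side. Typical usage is unchanged, for example (error code taken from the table in this header):

// Hypothetical error reporting with the pimpl-based Status.
#include <iostream>
#include "include/api/status.h"

void Report() {
  mindspore::Status st(mindspore::kCoreFailed, "something went wrong");
  if (st != mindspore::kSuccess) {
    std::cout << st.ToString() << std::endl;  // message rebuilt via ToCString()
  }
}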

View File

@@ -21,6 +21,7 @@
#include <vector>
#include <memory>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"
#ifdef _WIN32
#define MS_API __declspec(dllexport)
@@ -42,18 +43,18 @@ class MS_API MSTensor {
public:
class Impl;
static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
MSTensor();
explicit MSTensor(const std::shared_ptr<Impl> &impl);
MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
~MSTensor();
const std::string &Name() const;
inline std::string Name() const;
enum DataType DataType() const;
const std::vector<int64_t> &Shape() const;
int64_t ElementNum() const;
@@ -68,6 +69,15 @@ class MS_API MSTensor {
bool operator==(std::nullptr_t) const;
private:
// api without std::string
static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
std::vector<char> CharName() const;
friend class ModelImpl;
explicit MSTensor(std::nullptr_t);
std::shared_ptr<Impl> impl_;
@@ -92,5 +102,21 @@ class MS_API Buffer {
class Impl;
std::shared_ptr<Impl> impl_;
};
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
return CreateTensor(StringToChar(name), type, shape, data, data_len);
}
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
}
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
std::string MSTensor::Name() const { return CharToString(CharName()); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_TYPES_H
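MSTensor creation follows the same shape: the std::string overloads shown here are inline and forward to private char-vector overloads. A hedged construction example (the kNumberTypeFloat32 enumerator, tensor name, and shape are illustrative):

// Hypothetical tensor construction through the inline std::string overload.
#include <vector>
#include "include/api/types.h"

mindspore::MSTensor MakeInput() {
  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  return mindspore::MSTensor::CreateTensor("input0", mindspore::DataType::kNumberTypeFloat32,
                                           {2, 2}, data.data(), data.size() * sizeof(float));
}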

View File

@@ -14,6 +14,9 @@
* limitations under the License.
*/
#include "include/api/context.h"
#include <any>
#include <map>
#include <type_traits>
#include "utils/log_adapter.h"
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
@@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
namespace mindspore {
template <class T>
static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
auto iter = context->params.find(key);
if (iter == context->params.end()) {
return T();
struct Context::Data {
std::map<std::string, std::any> params;
};
Context::Context() : data(std::make_shared<Data>()) {}
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
static U empty_result;
if (context == nullptr || context->data == nullptr) {
return empty_result;
}
auto iter = context->data->params.find(key);
if (iter == context->data->params.end()) {
return empty_result;
}
const std::any &value = iter->second;
if (value.type() != typeid(T)) {
return T();
if (value.type() != typeid(U)) {
return empty_result;
}
return std::any_cast<T>(value);
return std::any_cast<const U &>(value);
}
std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
@@ -47,22 +60,31 @@ std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
return g_context;
}
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceTarget] = device_target;
if (global_context->data == nullptr) {
global_context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(global_context->data);
}
global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target);
}
std::string GlobalContext::GetGlobalDeviceTarget() {
std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
return StringToChar(ref);
}
void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceID] = device_id;
if (global_context->data == nullptr) {
global_context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(global_context->data);
}
global_context->data->params[kGlobalContextDeviceID] = device_id;
}
uint32_t GlobalContext::GetGlobalDeviceID() {
@@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() {
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
}
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInsertOpCfgPath] = cfg_path;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
}
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
return StringToChar(ref);
}
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputFormat] = format;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInputFormat] = CharToString(format);
}
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputFormat);
const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat);
return StringToChar(ref);
}
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputShape] = shape;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInputShape] = CharToString(shape);
}
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputShape);
const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape);
return StringToChar(ref);
}
void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOutputType] = output_type;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionOutputType] = output_type;
}
enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
@@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &contex
return GetValue<enum DataType>(context, kModelOptionOutputType);
}
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionPrecisionMode] = precision_mode;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode);
}
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionPrecisionMode);
const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode);
return StringToChar(ref);
}
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode) {
const std::vector<char> &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
}
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
return StringToChar(ref);
}
} // namespace mindspore

View File

@@ -68,12 +68,13 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
Model::~Model() {}
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type)) {
bool Model::CheckModelSupport(const std::vector<char> &device_type, ModelType model_type) {
std::string device_type_str = CharToString(device_type);
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
return false;
}
auto first_iter = kSupportedModelMap.find(device_type);
auto first_iter = kSupportedModelMap.find(device_type_str);
if (first_iter == kSupportedModelMap.end()) {
return false;
}

View File

@@ -84,17 +84,18 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
}
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
std::string file_path = CharToString(file);
if (model_type == kMindIR) {
FuncGraphPtr anf_graph = LoadMindIR(file);
FuncGraphPtr anf_graph = LoadMindIR(file_path);
if (anf_graph == nullptr) {
MS_LOG(EXCEPTION) << "Load model failed.";
}
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
} else if (model_type == kOM) {
Buffer data = ReadFile(file);
Buffer data = ReadFile(file_path);
if (data.Data() == nullptr) {
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
MS_LOG(EXCEPTION) << "Read file " << file_path << " failed.";
}
return Graph(std::make_shared<Graph::GraphData>(data, kOM));
}

View File

@@ -134,10 +134,11 @@ class TensorReferenceImpl : public MSTensor::Impl {
std::vector<int64_t> shape_;
};
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
std::string name_str = CharToString(name);
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len);
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
@@ -148,10 +149,11 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con
}
}
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
std::string name_str = CharToString(name);
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len);
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
@@ -165,9 +167,9 @@ MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type,
MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); }
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {}
MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len)
: impl_(std::make_shared<TensorDefaultImpl>(CharToString(name), type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default;
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
@@ -179,9 +181,9 @@ MSTensor MSTensor::Clone() const {
return ret;
}
const std::string &MSTensor::Name() const {
std::vector<char> MSTensor::CharName() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Name();
return StringToChar(impl_->Name());
}
enum DataType MSTensor::DataType() const {

View File

@@ -296,8 +296,7 @@ else()
endif()
endif()
add_dependencies(_c_dataengine mindspore_shared_lib)
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
target_link_libraries(_c_dataengine PRIVATE mindspore_core mindspore_shared_lib)
if(USE_GLOG)
target_link_libraries(_c_dataengine PRIVATE mindspore::glog)

View File

@@ -686,6 +686,9 @@ class Tensor {
/// pointer to the end of the physical data
unsigned char *data_end_ = nullptr;
/// shape for interpretation of YUV image
std::vector<uint32_t> yuv_shape_;
private:
friend class DETensor;

View File

@@ -24,16 +24,42 @@
#include <sstream>
namespace mindspore {
Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) {
status_code_ = code;
line_of_code_ = line_of_code;
file_name_ = std::string(file_name);
err_description_ = extra;
struct Status::Data {
enum StatusCode status_code = kSuccess;
std::string status_msg;
int line_of_code = -1;
std::string file_name;
std::string err_description;
};
Status::Status() : data_(std::make_shared<Data>()) {}
Status::Status(enum StatusCode status_code, const std::vector<char> &status_msg) : data_(std::make_shared<Data>()) {
if (data_ == nullptr) {
return;
}
data_->status_msg = CharToString(status_msg);
data_->status_code = status_code;
}
Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra)
: data_(std::make_shared<Data>()) {
if (data_ == nullptr) {
return;
}
data_->status_code = code;
data_->line_of_code = line_of_code;
if (file_name != nullptr) {
data_->file_name = file_name;
}
data_->err_description = CharToString(extra);
std::ostringstream ss;
#ifndef ENABLE_ANDROID
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". ";
if (!extra.empty()) {
ss << extra;
if (!data_->err_description.empty()) {
ss << data_->err_description;
}
ss << "\n";
#endif
@@ -42,10 +68,38 @@ Status::Status(enum StatusCode code, int line_of_code, const char *file_name, co
if (file_name != nullptr) {
ss << "File : " << file_name << "\n";
}
status_msg_ = ss.str();
data_->status_msg = ss.str();
}
std::string Status::CodeAsString(enum StatusCode c) {
enum StatusCode Status::StatusCode() const {
if (data_ == nullptr) {
return kSuccess;
}
return data_->status_code;
}
std::vector<char> Status::ToCString() const {
if (data_ == nullptr) {
return std::vector<char>();
}
return StringToChar(data_->status_msg);
}
int Status::GetLineOfCode() const {
if (data_ == nullptr) {
return -1;
}
return data_->line_of_code;
}
std::vector<char> Status::GetErrDescriptionChar() const {
if (data_ == nullptr) {
return std::vector<char>();
}
return StringToChar(data_->status_msg);
}
std::vector<char> Status::CodeAsCString(enum StatusCode c) {
static std::map<enum StatusCode, std::string> info_map = {{kSuccess, "No error occurs."},
// Core
{kCoreFailed, "Common error code."},
@@ -98,7 +152,7 @@ std::string Status::CodeAsString(enum StatusCode c) {
{kLiteInferInvalid, "Invalid infer shape before runtime."},
{kLiteInputParamInvalid, "Invalid input param by user."}};
auto iter = info_map.find(c);
return iter == info_map.end() ? "Unknown error" : iter->second;
return StringToChar(iter == info_map.end() ? "Unknown error" : iter->second);
}
std::ostream &operator<<(std::ostream &os, const Status &s) {
@@ -106,22 +160,48 @@ std::ostream &operator<<(std::ostream &os, const Status &s) {
return os;
}
const std::string &Status::SetErrDescription(const std::string &err_description) {
err_description_ = err_description;
std::vector<char> Status::SetErrDescription(const std::vector<char> &err_description) {
if (data_ == nullptr) {
return std::vector<char>();
}
data_->err_description = CharToString(err_description);
std::ostringstream ss;
#ifndef ENABLE_ANDROID
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(status_code_) << ". ";
if (!err_description_.empty()) {
ss << err_description_;
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". ";
if (!data_->err_description.empty()) {
ss << data_->err_description;
}
ss << "\n";
#endif
if (line_of_code_ > 0 && !file_name_.empty()) {
ss << "Line of code : " << line_of_code_ << "\n";
ss << "File : " << file_name_ << "\n";
if (data_->line_of_code > 0 && !data_->file_name.empty()) {
ss << "Line of code : " << data_->line_of_code << "\n";
ss << "File : " << data_->file_name << "\n";
}
status_msg_ = ss.str();
return status_msg_;
data_->status_msg = ss.str();
return StringToChar(data_->status_msg);
}
bool Status::operator==(const Status &other) const {
if (data_ == nullptr && other.data_ == nullptr) {
return true;
}
if (data_ == nullptr || other.data_ == nullptr) {
return false;
}
return data_->status_code == other.data_->status_code;
}
bool Status::operator==(enum StatusCode other_code) const { return StatusCode() == other_code; }
bool Status::operator!=(const Status &other) const { return !operator==(other); }
bool Status::operator!=(enum StatusCode other_code) const { return !operator==(other_code); }
Status::operator bool() const { return (StatusCode() == kSuccess); }
Status::operator int() const { return static_cast<int>(StatusCode()); }
Status Status::OK() { return StatusCode::kSuccess; }
bool Status::IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
bool Status::IsError() const { return !IsOk(); }
} // namespace mindspore

View File

@@ -67,7 +67,7 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
Model::~Model() {}
bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
bool Model::CheckModelSupport(const std::vector<char> &, ModelType) {
MS_LOG(ERROR) << "Unsupported feature.";
return false;
}

View File

@@ -47,7 +47,7 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
return graph;
}
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
MS_LOG(ERROR) << "Unsupported Feature.";
return Graph(nullptr);
}

View File

@@ -60,16 +60,16 @@ class Buffer::Impl {
MSTensor::MSTensor() : impl_(std::make_shared<Impl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) {}
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: impl_(std::make_shared<Impl>(name, type, shape, data, data_len)) {}
MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len)
: impl_(std::make_shared<Impl>(CharToString(name), type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default;
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
auto impl = std::make_shared<Impl>(name, type, shape, data, data_len);
auto impl = std::make_shared<Impl>(CharToString(name), type, shape, data, data_len);
if (impl == nullptr) {
MS_LOG(ERROR) << "Allocate tensor impl failed.";
return MSTensor(nullptr);
@@ -77,7 +77,7 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con
return MSTensor(impl);
}
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
auto tensor = CreateTensor(name, type, shape, data, data_len);
if (tensor == nullptr) {
@@ -98,13 +98,12 @@ MSTensor MSTensor::Clone() const {
return ret;
}
const std::string &MSTensor::Name() const {
static std::string empty = "";
std::vector<char> MSTensor::CharName() const {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor inpmlement.";
return empty;
return std::vector<char>();
}
return impl_->Name();
return StringToChar(impl_->Name());
}
int64_t MSTensor::ElementNum() const {

View File

@@ -60,12 +60,6 @@ TEST_F(TestCxxApiContext, test_context_ascend310_context_nullptr_FAILED) {
EXPECT_ANY_THROW(ModelContext::GetInsertOpConfigPath(nullptr));
}
TEST_F(TestCxxApiContext, test_context_ascend310_context_wrong_type_SUCCESS) {
auto ctx = std::make_shared<ModelContext>();
ctx->params["mindspore.option.op_select_impl_mode"] = 5;
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");
}
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
auto ctx = std::make_shared<ModelContext>();
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");