forked from mindspore-Ecosystem/mindspore
api support dual abi
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
674bca3fe4
commit
504f215800
|
@ -16,11 +16,11 @@
|
|||
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
#define MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
|
||||
#include <map>
|
||||
#include <any>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kDeviceTypeAscend310 = "Ascend310";
|
||||
|
@ -28,38 +28,108 @@ constexpr auto kDeviceTypeAscend910 = "Ascend910";
|
|||
constexpr auto kDeviceTypeGPU = "GPU";
|
||||
|
||||
struct MS_API Context {
|
||||
public:
|
||||
Context();
|
||||
virtual ~Context() = default;
|
||||
std::map<std::string, std::any> params;
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data;
|
||||
};
|
||||
|
||||
struct MS_API GlobalContext : public Context {
|
||||
public:
|
||||
static std::shared_ptr<Context> GetGlobalContext();
|
||||
|
||||
static void SetGlobalDeviceTarget(const std::string &device_target);
|
||||
static std::string GetGlobalDeviceTarget();
|
||||
static inline void SetGlobalDeviceTarget(const std::string &device_target);
|
||||
static inline std::string GetGlobalDeviceTarget();
|
||||
|
||||
static void SetGlobalDeviceID(const uint32_t &device_id);
|
||||
static uint32_t GetGlobalDeviceID();
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
|
||||
static std::vector<char> GetGlobalDeviceTargetChar();
|
||||
};
|
||||
|
||||
struct MS_API ModelContext : public Context {
|
||||
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
|
||||
static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
|
||||
public:
|
||||
static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
|
||||
static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
|
||||
static std::string GetInputFormat(const std::shared_ptr<Context> &context);
|
||||
static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
|
||||
static inline std::string GetInputFormat(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
|
||||
static std::string GetInputShape(const std::shared_ptr<Context> &context);
|
||||
static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
|
||||
static inline std::string GetInputShape(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
|
||||
static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
|
||||
static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
|
||||
static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
|
||||
static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
|
||||
static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
|
||||
static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
||||
const std::string &op_select_impl_mode);
|
||||
static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
|
||||
static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format);
|
||||
static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape);
|
||||
static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode);
|
||||
static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
||||
const std::vector<char> &op_select_impl_mode);
|
||||
static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
|
||||
};
|
||||
|
||||
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
|
||||
SetGlobalDeviceTarget(StringToChar(device_target));
|
||||
}
|
||||
std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }
|
||||
|
||||
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
|
||||
SetInsertOpConfigPath(context, StringToChar(cfg_path));
|
||||
}
|
||||
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetInsertOpConfigPathChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
|
||||
SetInputFormat(context, StringToChar(format));
|
||||
}
|
||||
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetInputFormatChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
|
||||
SetInputShape(context, StringToChar(shape));
|
||||
}
|
||||
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetInputShapeChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
|
||||
SetPrecisionMode(context, StringToChar(precision_mode));
|
||||
}
|
||||
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetPrecisionModeChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
||||
const std::string &op_select_impl_mode) {
|
||||
SetOpSelectImplMode(context, StringToChar(op_select_impl_mode));
|
||||
}
|
||||
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetOpSelectImplModeChar(context));
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
|
||||
#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
inline std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }
|
||||
inline std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
|
|
@ -17,7 +17,6 @@
|
|||
#define MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "include/api/types.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/cell.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ModelImpl;
|
||||
|
@ -46,10 +47,16 @@ class MS_API Model {
|
|||
std::vector<MSTensor> GetInputs();
|
||||
std::vector<MSTensor> GetOutputs();
|
||||
|
||||
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||
static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type);
|
||||
std::shared_ptr<ModelImpl> impl_;
|
||||
};
|
||||
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
|
||||
return CheckModelSupport(StringToChar(device_type), model_type);
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
||||
|
|
|
@ -24,16 +24,24 @@
|
|||
#include "include/api/types.h"
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
class MS_API Serialization {
|
||||
public:
|
||||
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
|
||||
static Graph LoadModel(const std::string &file, ModelType model_type);
|
||||
inline static Graph LoadModel(const std::string &file, ModelType model_type);
|
||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
||||
|
||||
private:
|
||||
static Graph LoadModel(const std::vector<char> &file, ModelType model_type);
|
||||
};
|
||||
|
||||
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
|
||||
return LoadModel(StringToChar(file), model_type);
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
|
|
|
@ -16,9 +16,13 @@
|
|||
#ifndef MINDSPORE_INCLUDE_API_STATUS_H
|
||||
#define MINDSPORE_INCLUDE_API_STATUS_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <ostream>
|
||||
#include <climits>
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
enum CompCode : uint32_t {
|
||||
|
@ -100,46 +104,61 @@ enum StatusCode : uint32_t {
|
|||
kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
|
||||
};
|
||||
|
||||
class Status {
|
||||
class MS_API Status {
|
||||
public:
|
||||
Status() : status_code_(kSuccess), line_of_code_(-1) {}
|
||||
Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit)
|
||||
: status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {}
|
||||
Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
|
||||
Status();
|
||||
inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit)
|
||||
inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
|
||||
|
||||
~Status() = default;
|
||||
|
||||
enum StatusCode StatusCode() const { return status_code_; }
|
||||
const std::string &ToString() const { return status_msg_; }
|
||||
enum StatusCode StatusCode() const;
|
||||
inline std::string ToString() const;
|
||||
|
||||
int GetLineOfCode() const { return line_of_code_; }
|
||||
const std::string &GetErrDescription() const { return status_msg_; }
|
||||
const std::string &SetErrDescription(const std::string &err_description);
|
||||
int GetLineOfCode() const;
|
||||
inline std::string GetErrDescription() const;
|
||||
inline std::string SetErrDescription(const std::string &err_description);
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &os, const Status &s);
|
||||
|
||||
bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
|
||||
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
|
||||
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
|
||||
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
|
||||
bool operator==(const Status &other) const;
|
||||
bool operator==(enum StatusCode other_code) const;
|
||||
bool operator!=(const Status &other) const;
|
||||
bool operator!=(enum StatusCode other_code) const;
|
||||
|
||||
explicit operator bool() const { return (status_code_ == kSuccess); }
|
||||
explicit operator int() const { return static_cast<int>(status_code_); }
|
||||
explicit operator bool() const;
|
||||
explicit operator int() const;
|
||||
|
||||
static Status OK() { return Status(StatusCode::kSuccess); }
|
||||
static Status OK();
|
||||
|
||||
bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
|
||||
bool IsOk() const;
|
||||
|
||||
bool IsError() const { return !IsOk(); }
|
||||
bool IsError() const;
|
||||
|
||||
static std::string CodeAsString(enum StatusCode c);
|
||||
static inline std::string CodeAsString(enum StatusCode c);
|
||||
|
||||
private:
|
||||
enum StatusCode status_code_;
|
||||
std::string status_msg_;
|
||||
int line_of_code_;
|
||||
std::string file_name_;
|
||||
std::string err_description_;
|
||||
// api without std::string
|
||||
explicit Status(enum StatusCode status_code, const std::vector<char> &status_msg);
|
||||
Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra);
|
||||
std::vector<char> ToCString() const;
|
||||
std::vector<char> GetErrDescriptionChar() const;
|
||||
std::vector<char> SetErrDescription(const std::vector<char> &err_description);
|
||||
static std::vector<char> CodeAsCString(enum StatusCode c);
|
||||
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
Status::Status(enum StatusCode status_code, const std::string &status_msg)
|
||||
: Status(status_code, StringToChar(status_msg)) {}
|
||||
Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra)
|
||||
: Status(code, line_of_code, file_name, StringToChar(extra)) {}
|
||||
std::string Status::ToString() const { return CharToString(ToCString()); }
|
||||
std::string Status::GetErrDescription() const { return CharToString(GetErrDescriptionChar()); }
|
||||
std::string Status::SetErrDescription(const std::string &err_description) {
|
||||
return CharToString(SetErrDescription(StringToChar(err_description)));
|
||||
}
|
||||
std::string Status::CodeAsString(enum StatusCode c) { return CharToString(CodeAsCString(c)); }
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_STATUS_H
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#define MS_API __declspec(dllexport)
|
||||
|
@ -42,18 +43,18 @@ class MS_API MSTensor {
|
|||
public:
|
||||
class Impl;
|
||||
|
||||
static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
|
||||
MSTensor();
|
||||
explicit MSTensor(const std::shared_ptr<Impl> &impl);
|
||||
MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len);
|
||||
inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len);
|
||||
~MSTensor();
|
||||
|
||||
const std::string &Name() const;
|
||||
inline std::string Name() const;
|
||||
enum DataType DataType() const;
|
||||
const std::vector<int64_t> &Shape() const;
|
||||
int64_t ElementNum() const;
|
||||
|
@ -68,6 +69,15 @@ class MS_API MSTensor {
|
|||
bool operator==(std::nullptr_t) const;
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len);
|
||||
std::vector<char> CharName() const;
|
||||
|
||||
friend class ModelImpl;
|
||||
explicit MSTensor(std::nullptr_t);
|
||||
std::shared_ptr<Impl> impl_;
|
||||
|
@ -92,5 +102,21 @@ class MS_API Buffer {
|
|||
class Impl;
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
return CreateTensor(StringToChar(name), type, shape, data, data_len);
|
||||
}
|
||||
|
||||
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
|
||||
}
|
||||
|
||||
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
|
||||
|
||||
std::string MSTensor::Name() const { return CharToString(CharName()); }
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_TYPES_H
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/context.h"
|
||||
#include <any>
|
||||
#include <map>
|
||||
#include <type_traits>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
|
||||
|
@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
|
|||
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
|
||||
|
||||
namespace mindspore {
|
||||
template <class T>
|
||||
static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
|
||||
auto iter = context->params.find(key);
|
||||
if (iter == context->params.end()) {
|
||||
return T();
|
||||
struct Context::Data {
|
||||
std::map<std::string, std::any> params;
|
||||
};
|
||||
|
||||
Context::Context() : data(std::make_shared<Data>()) {}
|
||||
|
||||
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
|
||||
static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
|
||||
static U empty_result;
|
||||
if (context == nullptr || context->data == nullptr) {
|
||||
return empty_result;
|
||||
}
|
||||
auto iter = context->data->params.find(key);
|
||||
if (iter == context->data->params.end()) {
|
||||
return empty_result;
|
||||
}
|
||||
const std::any &value = iter->second;
|
||||
if (value.type() != typeid(T)) {
|
||||
return T();
|
||||
if (value.type() != typeid(U)) {
|
||||
return empty_result;
|
||||
}
|
||||
|
||||
return std::any_cast<T>(value);
|
||||
return std::any_cast<const U &>(value);
|
||||
}
|
||||
|
||||
std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
|
||||
|
@ -47,22 +60,31 @@ std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
|
|||
return g_context;
|
||||
}
|
||||
|
||||
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
|
||||
void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
|
||||
auto global_context = GetGlobalContext();
|
||||
MS_EXCEPTION_IF_NULL(global_context);
|
||||
global_context->params[kGlobalContextDeviceTarget] = device_target;
|
||||
if (global_context->data == nullptr) {
|
||||
global_context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(global_context->data);
|
||||
}
|
||||
global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target);
|
||||
}
|
||||
|
||||
std::string GlobalContext::GetGlobalDeviceTarget() {
|
||||
std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
|
||||
auto global_context = GetGlobalContext();
|
||||
MS_EXCEPTION_IF_NULL(global_context);
|
||||
return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
|
||||
const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
|
||||
auto global_context = GetGlobalContext();
|
||||
MS_EXCEPTION_IF_NULL(global_context);
|
||||
global_context->params[kGlobalContextDeviceID] = device_id;
|
||||
if (global_context->data == nullptr) {
|
||||
global_context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(global_context->data);
|
||||
}
|
||||
global_context->data->params[kGlobalContextDeviceID] = device_id;
|
||||
}
|
||||
|
||||
uint32_t GlobalContext::GetGlobalDeviceID() {
|
||||
|
@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() {
|
|||
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
|
||||
}
|
||||
|
||||
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
|
||||
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->params[kModelOptionInsertOpCfgPath] = cfg_path;
|
||||
if (context->data == nullptr) {
|
||||
context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(context->data);
|
||||
}
|
||||
context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
|
||||
}
|
||||
|
||||
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
|
||||
std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
|
||||
const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
|
||||
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->params[kModelOptionInputFormat] = format;
|
||||
if (context->data == nullptr) {
|
||||
context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(context->data);
|
||||
}
|
||||
context->data->params[kModelOptionInputFormat] = CharToString(format);
|
||||
}
|
||||
|
||||
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
|
||||
std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return GetValue<std::string>(context, kModelOptionInputFormat);
|
||||
const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
|
||||
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->params[kModelOptionInputShape] = shape;
|
||||
if (context->data == nullptr) {
|
||||
context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(context->data);
|
||||
}
|
||||
context->data->params[kModelOptionInputShape] = CharToString(shape);
|
||||
}
|
||||
|
||||
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
|
||||
std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return GetValue<std::string>(context, kModelOptionInputShape);
|
||||
const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->params[kModelOptionOutputType] = output_type;
|
||||
if (context->data == nullptr) {
|
||||
context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(context->data);
|
||||
}
|
||||
context->data->params[kModelOptionOutputType] = output_type;
|
||||
}
|
||||
|
||||
enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
|
||||
|
@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &contex
|
|||
return GetValue<enum DataType>(context, kModelOptionOutputType);
|
||||
}
|
||||
|
||||
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
|
||||
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->params[kModelOptionPrecisionMode] = precision_mode;
|
||||
if (context->data == nullptr) {
|
||||
context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(context->data);
|
||||
}
|
||||
context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode);
|
||||
}
|
||||
|
||||
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
|
||||
std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return GetValue<std::string>(context, kModelOptionPrecisionMode);
|
||||
const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
||||
const std::string &op_select_impl_mode) {
|
||||
const std::vector<char> &op_select_impl_mode) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
|
||||
if (context->data == nullptr) {
|
||||
context->data = std::make_shared<Data>();
|
||||
MS_EXCEPTION_IF_NULL(context->data);
|
||||
}
|
||||
context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
|
||||
}
|
||||
|
||||
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
|
||||
std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
|
||||
const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -68,12 +68,13 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
|
|||
|
||||
Model::~Model() {}
|
||||
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
|
||||
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type)) {
|
||||
bool Model::CheckModelSupport(const std::vector<char> &device_type, ModelType model_type) {
|
||||
std::string device_type_str = CharToString(device_type);
|
||||
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto first_iter = kSupportedModelMap.find(device_type);
|
||||
auto first_iter = kSupportedModelMap.find(device_type_str);
|
||||
if (first_iter == kSupportedModelMap.end()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -84,17 +84,18 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
|
|||
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
|
||||
}
|
||||
|
||||
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
|
||||
Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
|
||||
std::string file_path = CharToString(file);
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = LoadMindIR(file);
|
||||
FuncGraphPtr anf_graph = LoadMindIR(file_path);
|
||||
if (anf_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Load model failed.";
|
||||
}
|
||||
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
} else if (model_type == kOM) {
|
||||
Buffer data = ReadFile(file);
|
||||
Buffer data = ReadFile(file_path);
|
||||
if (data.Data() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
|
||||
MS_LOG(EXCEPTION) << "Read file " << file_path << " failed.";
|
||||
}
|
||||
return Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
||||
}
|
||||
|
|
|
@ -134,10 +134,11 @@ class TensorReferenceImpl : public MSTensor::Impl {
|
|||
std::vector<int64_t> shape_;
|
||||
};
|
||||
|
||||
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
std::string name_str = CharToString(name);
|
||||
try {
|
||||
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len);
|
||||
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len);
|
||||
return MSTensor(impl);
|
||||
} catch (const std::bad_alloc &) {
|
||||
MS_LOG(ERROR) << "Malloc memory failed.";
|
||||
|
@ -148,10 +149,11 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con
|
|||
}
|
||||
}
|
||||
|
||||
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
std::string name_str = CharToString(name);
|
||||
try {
|
||||
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len);
|
||||
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len);
|
||||
return MSTensor(impl);
|
||||
} catch (const std::bad_alloc &) {
|
||||
MS_LOG(ERROR) << "Malloc memory failed.";
|
||||
|
@ -165,9 +167,9 @@ MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type,
|
|||
MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {}
|
||||
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
|
||||
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); }
|
||||
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {}
|
||||
MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len)
|
||||
: impl_(std::make_shared<TensorDefaultImpl>(CharToString(name), type, shape, data, data_len)) {}
|
||||
MSTensor::~MSTensor() = default;
|
||||
|
||||
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
|
||||
|
@ -179,9 +181,9 @@ MSTensor MSTensor::Clone() const {
|
|||
return ret;
|
||||
}
|
||||
|
||||
const std::string &MSTensor::Name() const {
|
||||
std::vector<char> MSTensor::CharName() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Name();
|
||||
return StringToChar(impl_->Name());
|
||||
}
|
||||
|
||||
enum DataType MSTensor::DataType() const {
|
||||
|
|
|
@ -296,8 +296,7 @@ else()
|
|||
endif()
|
||||
endif()
|
||||
|
||||
add_dependencies(_c_dataengine mindspore_shared_lib)
|
||||
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
|
||||
target_link_libraries(_c_dataengine PRIVATE mindspore_core mindspore_shared_lib)
|
||||
|
||||
if(USE_GLOG)
|
||||
target_link_libraries(_c_dataengine PRIVATE mindspore::glog)
|
||||
|
|
|
@ -686,6 +686,9 @@ class Tensor {
|
|||
/// pointer to the end of the physical data
|
||||
unsigned char *data_end_ = nullptr;
|
||||
|
||||
/// shape for interpretation of YUV image
|
||||
std::vector<uint32_t> yuv_shape_;
|
||||
|
||||
private:
|
||||
friend class DETensor;
|
||||
|
||||
|
|
|
@ -24,16 +24,42 @@
|
|||
#include <sstream>
|
||||
|
||||
namespace mindspore {
|
||||
Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) {
|
||||
status_code_ = code;
|
||||
line_of_code_ = line_of_code;
|
||||
file_name_ = std::string(file_name);
|
||||
err_description_ = extra;
|
||||
struct Status::Data {
|
||||
enum StatusCode status_code = kSuccess;
|
||||
std::string status_msg;
|
||||
int line_of_code = -1;
|
||||
std::string file_name;
|
||||
std::string err_description;
|
||||
};
|
||||
|
||||
Status::Status() : data_(std::make_shared<Data>()) {}
|
||||
|
||||
Status::Status(enum StatusCode status_code, const std::vector<char> &status_msg) : data_(std::make_shared<Data>()) {
|
||||
if (data_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_->status_msg = CharToString(status_msg);
|
||||
data_->status_code = status_code;
|
||||
}
|
||||
|
||||
Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra)
|
||||
: data_(std::make_shared<Data>()) {
|
||||
if (data_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
data_->status_code = code;
|
||||
data_->line_of_code = line_of_code;
|
||||
if (file_name != nullptr) {
|
||||
data_->file_name = file_name;
|
||||
}
|
||||
data_->err_description = CharToString(extra);
|
||||
|
||||
std::ostringstream ss;
|
||||
#ifndef ENABLE_ANDROID
|
||||
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". ";
|
||||
if (!extra.empty()) {
|
||||
ss << extra;
|
||||
if (!data_->err_description.empty()) {
|
||||
ss << data_->err_description;
|
||||
}
|
||||
ss << "\n";
|
||||
#endif
|
||||
|
@ -42,10 +68,38 @@ Status::Status(enum StatusCode code, int line_of_code, const char *file_name, co
|
|||
if (file_name != nullptr) {
|
||||
ss << "File : " << file_name << "\n";
|
||||
}
|
||||
status_msg_ = ss.str();
|
||||
data_->status_msg = ss.str();
|
||||
}
|
||||
|
||||
std::string Status::CodeAsString(enum StatusCode c) {
|
||||
enum StatusCode Status::StatusCode() const {
|
||||
if (data_ == nullptr) {
|
||||
return kSuccess;
|
||||
}
|
||||
return data_->status_code;
|
||||
}
|
||||
|
||||
std::vector<char> Status::ToCString() const {
|
||||
if (data_ == nullptr) {
|
||||
return std::vector<char>();
|
||||
}
|
||||
return StringToChar(data_->status_msg);
|
||||
}
|
||||
|
||||
int Status::GetLineOfCode() const {
|
||||
if (data_ == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
return data_->line_of_code;
|
||||
}
|
||||
|
||||
std::vector<char> Status::GetErrDescriptionChar() const {
|
||||
if (data_ == nullptr) {
|
||||
return std::vector<char>();
|
||||
}
|
||||
return StringToChar(data_->status_msg);
|
||||
}
|
||||
|
||||
std::vector<char> Status::CodeAsCString(enum StatusCode c) {
|
||||
static std::map<enum StatusCode, std::string> info_map = {{kSuccess, "No error occurs."},
|
||||
// Core
|
||||
{kCoreFailed, "Common error code."},
|
||||
|
@ -98,7 +152,7 @@ std::string Status::CodeAsString(enum StatusCode c) {
|
|||
{kLiteInferInvalid, "Invalid infer shape before runtime."},
|
||||
{kLiteInputParamInvalid, "Invalid input param by user."}};
|
||||
auto iter = info_map.find(c);
|
||||
return iter == info_map.end() ? "Unknown error" : iter->second;
|
||||
return StringToChar(iter == info_map.end() ? "Unknown error" : iter->second);
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const Status &s) {
|
||||
|
@ -106,22 +160,48 @@ std::ostream &operator<<(std::ostream &os, const Status &s) {
|
|||
return os;
|
||||
}
|
||||
|
||||
const std::string &Status::SetErrDescription(const std::string &err_description) {
|
||||
err_description_ = err_description;
|
||||
std::vector<char> Status::SetErrDescription(const std::vector<char> &err_description) {
|
||||
if (data_ == nullptr) {
|
||||
return std::vector<char>();
|
||||
}
|
||||
data_->err_description = CharToString(err_description);
|
||||
std::ostringstream ss;
|
||||
#ifndef ENABLE_ANDROID
|
||||
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(status_code_) << ". ";
|
||||
if (!err_description_.empty()) {
|
||||
ss << err_description_;
|
||||
ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". ";
|
||||
if (!data_->err_description.empty()) {
|
||||
ss << data_->err_description;
|
||||
}
|
||||
ss << "\n";
|
||||
#endif
|
||||
|
||||
if (line_of_code_ > 0 && !file_name_.empty()) {
|
||||
ss << "Line of code : " << line_of_code_ << "\n";
|
||||
ss << "File : " << file_name_ << "\n";
|
||||
if (data_->line_of_code > 0 && !data_->file_name.empty()) {
|
||||
ss << "Line of code : " << data_->line_of_code << "\n";
|
||||
ss << "File : " << data_->file_name << "\n";
|
||||
}
|
||||
status_msg_ = ss.str();
|
||||
return status_msg_;
|
||||
data_->status_msg = ss.str();
|
||||
return StringToChar(data_->status_msg);
|
||||
}
|
||||
|
||||
bool Status::operator==(const Status &other) const {
|
||||
if (data_ == nullptr && other.data_ == nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (data_ == nullptr || other.data_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return data_->status_code == other.data_->status_code;
|
||||
}
|
||||
|
||||
bool Status::operator==(enum StatusCode other_code) const { return StatusCode() == other_code; }
|
||||
bool Status::operator!=(const Status &other) const { return !operator==(other); }
|
||||
bool Status::operator!=(enum StatusCode other_code) const { return !operator==(other_code); }
|
||||
|
||||
Status::operator bool() const { return (StatusCode() == kSuccess); }
|
||||
Status::operator int() const { return static_cast<int>(StatusCode()); }
|
||||
|
||||
Status Status::OK() { return StatusCode::kSuccess; }
|
||||
bool Status::IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
|
||||
bool Status::IsError() const { return !IsOk(); }
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -67,7 +67,7 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
|
|||
|
||||
Model::~Model() {}
|
||||
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
|
||||
bool Model::CheckModelSupport(const std::vector<char> &, ModelType) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
|
|||
return graph;
|
||||
}
|
||||
|
||||
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
|
||||
Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return Graph(nullptr);
|
||||
}
|
||||
|
|
|
@ -60,16 +60,16 @@ class Buffer::Impl {
|
|||
MSTensor::MSTensor() : impl_(std::make_shared<Impl>()) {}
|
||||
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
|
||||
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) {}
|
||||
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: impl_(std::make_shared<Impl>(name, type, shape, data, data_len)) {}
|
||||
MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len)
|
||||
: impl_(std::make_shared<Impl>(CharToString(name), type, shape, data, data_len)) {}
|
||||
MSTensor::~MSTensor() = default;
|
||||
|
||||
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
|
||||
|
||||
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
auto impl = std::make_shared<Impl>(name, type, shape, data, data_len);
|
||||
auto impl = std::make_shared<Impl>(CharToString(name), type, shape, data, data_len);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||
return MSTensor(nullptr);
|
||||
|
@ -77,7 +77,7 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con
|
|||
return MSTensor(impl);
|
||||
}
|
||||
|
||||
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
auto tensor = CreateTensor(name, type, shape, data, data_len);
|
||||
if (tensor == nullptr) {
|
||||
|
@ -98,13 +98,12 @@ MSTensor MSTensor::Clone() const {
|
|||
return ret;
|
||||
}
|
||||
|
||||
const std::string &MSTensor::Name() const {
|
||||
static std::string empty = "";
|
||||
std::vector<char> MSTensor::CharName() const {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid tensor inpmlement.";
|
||||
return empty;
|
||||
return std::vector<char>();
|
||||
}
|
||||
return impl_->Name();
|
||||
return StringToChar(impl_->Name());
|
||||
}
|
||||
|
||||
int64_t MSTensor::ElementNum() const {
|
||||
|
|
|
@ -60,12 +60,6 @@ TEST_F(TestCxxApiContext, test_context_ascend310_context_nullptr_FAILED) {
|
|||
EXPECT_ANY_THROW(ModelContext::GetInsertOpConfigPath(nullptr));
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiContext, test_context_ascend310_context_wrong_type_SUCCESS) {
|
||||
auto ctx = std::make_shared<ModelContext>();
|
||||
ctx->params["mindspore.option.op_select_impl_mode"] = 5;
|
||||
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
|
||||
auto ctx = std::make_shared<ModelContext>();
|
||||
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");
|
||||
|
|
Loading…
Reference in New Issue