cxx api refactor: tensor/status/model

This commit is contained in:
lixian 2021-01-26 17:06:34 +08:00 committed by zhoufeng
parent f1009cb21b
commit 7d2fd6e76c
227 changed files with 4285 additions and 2786 deletions

View File

@ -65,7 +65,7 @@ install(
install(
TARGETS mindspore_shared_lib
LIBRARY DESTINATION ${INSTALL_LIB_DIR}
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
@ -327,7 +327,7 @@ install(
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/transforms.h
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision.h
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision_lite.h
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/minddata_eager.h
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/execute.h
DESTINATION ${INSTALL_BASE_DIR}/include/minddata/dataset/include
COMPONENT mindspore
)

View File

@ -109,6 +109,8 @@ if(PLATFORM_ARM64)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend* ops*" EXCLUDE)
if(ENABLE_TOOLS)
install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
@ -128,6 +130,8 @@ elseif(PLATFORM_ARM32)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
if(ENABLE_TOOLS)
install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
@ -162,6 +166,8 @@ elseif(WIN32)
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
set(WIN_LIB_DIR_RUN_X86 ${RUNTIME_PKG_NAME}/benchmark)
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.a DESTINATION ${WIN_LIB_DIR_RUN_X86}
COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -182,6 +188,8 @@ else()
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}

View File

@ -24,7 +24,6 @@
#include "include/api/graph.h"
namespace mindspore {
namespace api {
class InputAndOutput;
using Input = InputAndOutput;
using Output = InputAndOutput;
@ -35,7 +34,7 @@ class MS_API CellBase {
virtual ~CellBase() = default;
virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
virtual std::shared_ptr<CellBase> Clone() const = 0;
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { return SUCCESS; }
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; }
std::vector<Output> operator()(const std::vector<Input> &inputs) const;
};
@ -57,16 +56,16 @@ class MS_API ParameterCell final : public Cell<ParameterCell> {
ParameterCell(ParameterCell &&);
ParameterCell &operator=(ParameterCell &&);
explicit ParameterCell(const Tensor &);
ParameterCell &operator=(const Tensor &);
explicit ParameterCell(const MSTensor &);
ParameterCell &operator=(const MSTensor &);
explicit ParameterCell(Tensor &&);
ParameterCell &operator=(Tensor &&);
explicit ParameterCell(MSTensor &&);
ParameterCell &operator=(MSTensor &&);
Tensor GetTensor() const { return tensor_; }
MSTensor GetTensor() const { return tensor_; }
private:
Tensor tensor_;
MSTensor tensor_;
};
class MS_API OpCellBase : public CellBase {
@ -99,11 +98,9 @@ class MS_API GraphCell final : public Cell<GraphCell> {
explicit GraphCell(const std::shared_ptr<Graph> &);
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
private:
friend class ModelImpl;
@ -119,8 +116,8 @@ class MS_API InputAndOutput {
~InputAndOutput() = default;
// no explicit
InputAndOutput(const Tensor &); // NOLINT(runtime/explicit)
InputAndOutput(Tensor &&); // NOLINT(runtime/explicit)
InputAndOutput(const MSTensor &); // NOLINT(runtime/explicit)
InputAndOutput(MSTensor &&); // NOLINT(runtime/explicit)
InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
@ -132,6 +129,5 @@ class MS_API InputAndOutput {
std::vector<InputAndOutput> prev_;
int32_t index_;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CELL_H

View File

@ -16,26 +16,49 @@
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <map>
#include <any>
#include <string>
#include <memory>
#include "include/api/types.h"
namespace mindspore {
namespace api {
class MS_API Context {
public:
static Context &Instance();
const std::string &GetDeviceTarget() const;
Context &SetDeviceTarget(const std::string &device_target);
uint32_t GetDeviceID() const;
Context &SetDeviceID(uint32_t device_id);
constexpr auto kDeviceTypeAscend310 = "Ascend310";
constexpr auto kDeviceTypeAscend910 = "Ascend910";
private:
Context();
~Context();
class ContextImpl;
std::shared_ptr<ContextImpl> impl_;
struct MS_API Context {
virtual ~Context() = default;
std::map<std::string, std::any> params;
};
struct MS_API GlobalContext : public Context {
static std::shared_ptr<Context> GetGlobalContext();
static void SetGlobalDeviceTarget(const std::string &device_target);
static std::string GetGlobalDeviceTarget();
static void SetGlobalDeviceID(const uint32_t &device_id);
static uint32_t GetGlobalDeviceID();
};
struct MS_API ModelContext : public Context {
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
static std::string GetInputFormat(const std::shared_ptr<Context> &context);
static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
static std::string GetInputShape(const std::shared_ptr<Context> &context);
static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H

43
include/api/data_type.h Normal file
View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_
#define MINDSPORE_INCLUDE_API_DATA_TYPE_H_
namespace mindspore {
enum class DataType : int {
kTypeUnknown = 0,
kObjectTypeString = 12,
kObjectTypeList = 13,
kObjectTypeTuple = 14,
kObjectTypeTensorType = 17,
kNumberTypeBool = 30,
kNumberTypeInt8 = 32,
kNumberTypeInt16 = 33,
kNumberTypeInt32 = 34,
kNumberTypeInt64 = 35,
kNumberTypeUInt8 = 37,
kNumberTypeUInt16 = 38,
kNumberTypeUInt32 = 39,
kNumberTypeUInt64 = 40,
kNumberTypeFloat16 = 42,
kNumberTypeFloat32 = 43,
kNumberTypeFloat64 = 44,
kNumberTypeEnd = 46,
// add new enum here
kInvalidType = INT32_MAX,
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DATA_TYPE_H_

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_INCLUDE_API_GRAPH_H
#define MINDSPORE_INCLUDE_API_GRAPH_H
#include <cstddef>
#include <string>
#include <vector>
#include <map>
@ -24,21 +25,21 @@
#include "include/api/types.h"
namespace mindspore {
namespace api {
class MS_API Graph {
public:
class GraphData;
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
explicit Graph(std::nullptr_t);
~Graph();
enum ModelType ModelType() const;
bool operator==(std::nullptr_t) const;
private:
friend class GraphCell;
friend class ModelImpl;
std::shared_ptr<GraphData> graph_data_;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_GRAPH_H

View File

@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
#define MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
#include <string>
#include <memory>
#include <map>
#include <any>
#include "include/api/types.h"
namespace mindspore {
namespace lite {
/// \brief CpuBindMode defined for holding bind cpu strategy argument.
typedef enum : uint32_t {
NO_BIND = 0, /**< no bind */
HIGHER_CPU = 1, /**< bind higher cpu first */
MID_CPU = 2 /**< bind middle cpu first */
} CpuBindMode;
class Allocator;
} // namespace lite
struct MS_API Context {
public:
static void Clear(const std::shared_ptr<Context> &contxet);
static void SetAsDefault(const std::shared_ptr<Context> &contxet);
static void SetVendorName(const std::shared_ptr<Context> &contxet, const std::string &name);
static std::string GetVendorName(const std::shared_ptr<Context> &contxet);
static void SetThreadNum(const std::shared_ptr<Context> &contxet, int num);
static int GetThreadNum(const std::shared_ptr<Context> &contxet);
static void SetAllocator(const std::shared_ptr<Context> &contxet, std::shared_ptr<lite::Allocator> alloc);
static std::shared_ptr<lite::Allocator> GetAllocator(const std::shared_ptr<Context> &contxet);
static void ConfigCPU(const std::shared_ptr<Context> &contxet, bool config);
static bool IfCPUEnabled(const std::shared_ptr<Context> &contxet);
static void ConfigCPUFp16(const std::shared_ptr<Context> &contxet, bool config);
static bool IfCPUFp16Enabled(const std::shared_ptr<Context> &contxet);
static void SetCPUBindMode(const std::shared_ptr<Context> &contxet, lite::CpuBindMode mode);
static lite::CpuBindMode GetCPUBindMode(const std::shared_ptr<Context> &contxet);
static void ConfigGPU(const std::shared_ptr<Context> &contxet, bool config);
static bool IfGPUEnabled(const std::shared_ptr<Context> &contxet);
static void ConfigGPUFp16(const std::shared_ptr<Context> &contxet, bool config);
static bool IfGPUFp16Enabled(const std::shared_ptr<Context> &contxet);
static void ConfigNPU(const std::shared_ptr<Context> &contxet, bool config);
static bool IfNPUEnabled(const std::shared_ptr<Context> &contxet);
static void SetNPUFrequency(const std::shared_ptr<Context> &contxet, int freq);
static int GetNPUFrequency(const std::shared_ptr<Context> &contxet);
private:
std::map<std::string, std::any> context_;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_LITE_CONTEXT_H

View File

@ -20,41 +20,36 @@
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/cell.h"
namespace mindspore {
namespace api {
class ModelImpl;
// todo: minddata c++ interface
class DataSet {};
struct Context;
class MS_API Model {
public:
explicit Model(const std::vector<Output> &network);
explicit Model(const GraphCell &graph);
explicit Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context = nullptr);
explicit Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context = nullptr);
~Model();
Model(const Model &) = delete;
void operator=(const Model &) = delete;
Status Build(const std::map<std::string, std::string> &options);
Status Build();
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
Status Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
private:
std::shared_ptr<ModelImpl> impl_;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H

View File

@ -25,7 +25,6 @@
#include "include/api/cell.h"
namespace mindspore {
namespace api {
struct MS_API Conv2D : public OpCell<Conv2D> {
Conv2D() : OpCell("Conv2D") {}
~Conv2D() override = default;
@ -45,6 +44,5 @@ struct MS_API Conv2D : public OpCell<Conv2D> {
std::vector<int> dilation = {1, 1, 1, 1};
int group = 1;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_OPS_OPS_H

View File

@ -26,15 +26,14 @@
#include "include/api/graph.h"
namespace mindspore {
namespace api {
class MS_API Serialization {
public:
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
static Graph LoadModel(const std::string &file, ModelType model_type);
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H

View File

@ -17,37 +17,129 @@
#define MINDSPORE_INCLUDE_API_STATUS_H
#include <string>
#include <ostream>
#include <climits>
namespace mindspore {
namespace api {
enum StatusCode {
SUCCESS = 0,
FAILED,
INVALID_INPUTS,
// insert new status code here
UNKNOWN = 0xFFFFFFFF
enum CompCode : uint32_t {
kCore = 0x00000000u,
kMD = 0x10000000u,
kME = 0x20000000u,
kMC = 0x30000000u,
kLite = 0xF0000000u,
};
enum StatusCode : uint32_t {
kSuccess = 0,
// Core
kCoreFailed = kCore | 0x1,
// MD
kMDOutOfMemory = kMD | 1,
kMDShapeMisMatch = kMD | 2,
kMDInterrupted = kMD | 3,
kMDNoSpace = kMD | 4,
kMDPyFuncException = kMD | 5,
kMDDuplicateKey = kMD | 6,
kMDPythonInterpreterFailure = kMD | 7,
kMDTDTPushFailure = kMD | 8,
kMDFileNotExist = kMD | 9,
kMDProfilingError = kMD | 10,
kMDBoundingBoxOutOfBounds = kMD | 11,
kMDBoundingBoxInvalidShape = kMD | 12,
kMDSyntaxError = kMD | 13,
kMDTimeOut = kMD | 14,
kMDBuddySpaceFull = kMD | 15,
kMDNetWorkError = kMD | 16,
kMDNotImplementedYet = kMD | 17,
// Make this error code the last one. Add new error code above it.
kMDUnexpectedError = kMD | 127,
// ME
kMEFailed = kME | 0x1,
kMEInvalidInput = kME | 0x2,
// MC
kMCFailed = kMC | 0x1,
kMCDeviceError = kMC | 0x2,
kMCInvalidInput = kMC | 0x3,
kMCInvalidArgs = kMC | 0x4,
// Lite // Common error code, range: [-1, -100
kLiteError = kLite | (0x0FFFFFFF & -1), /**< Common error code. */
kLiteNullptr = kLite | (0x0FFFFFFF & -2), /**< NULL pointer returned.*/
kLiteParamInvalid = kLite | (0x0FFFFFFF & -3), /**< Invalid parameter.*/
kLiteNoChange = kLite | (0x0FFFFFFF & -4), /**< No change. */
kLiteSuccessExit = kLite | (0x0FFFFFFF & -5), /**< No error but exit. */
kLiteMemoryFailed = kLite | (0x0FFFFFFF & -6), /**< Fail to create memory. */
kLiteNotSupport = kLite | (0x0FFFFFFF & -7), /**< Fail to support. */
kLiteThreadPoolError = kLite | (0x0FFFFFFF & -8), /**< Error occur in thread pool. */
// Executor error code, range: [-100,-200)
kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
kLiteReentrantError = kLite | (0x0FFFFFFF & -102), /**< Exist executor running. */
// Graph error code, range: [-200,-300)
kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */
// Node error code, range: [-300,-400)
kLiteNotFindOp = kLite | (0x0FFFFFFF & -300), /**< Failed to find operator. */
kLiteInvalidOpName = kLite | (0x0FFFFFFF & -301), /**< Invalid operator name. */
kLiteInvalidOpAttr = kLite | (0x0FFFFFFF & -302), /**< Invalid operator attr. */
kLiteOpExecuteFailure = kLite | (0x0FFFFFFF & -303), /**< Failed to execution operator. */
// Tensor error code, range: [-400,-500)
kLiteFormatError = kLite | (0x0FFFFFFF & -400), /**< Failed to checking tensor format. */
// InferShape error code, range: [-500,-600)
kLiteInferError = kLite | (0x0FFFFFFF & -500), /**< Failed to infer shape. */
kLiteInferInvalid = kLite | (0x0FFFFFFF & -501), /**< Invalid infer shape before runtime. */
// User input param error code, range: [-600, 700)
kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
};
class Status {
public:
Status() : status_code_(FAILED) {}
Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit)
: status_code_(status_code), status_msg_(status_msg) {}
Status() : status_code_(kSuccess), line_of_code_(-1) {}
Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit)
: status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {}
Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
~Status() = default;
bool IsSuccess() const { return status_code_ == SUCCESS; }
enum StatusCode StatusCode() const { return status_code_; }
std::string StatusMessage() const { return status_msg_; }
const std::string &ToString() const { return status_msg_; }
int GetLineOfCode() const { return line_of_code_; }
const std::string &GetErrDescription() const { return status_msg_; }
const std::string &SetErrDescription(const std::string &err_description);
friend std::ostream &operator<<(std::ostream &os, const Status &s);
bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
operator bool() const = delete;
explicit operator bool() const { return (status_code_ == kSuccess); }
explicit operator int() const { return static_cast<int>(status_code_); }
static Status OK() { return Status(StatusCode::kSuccess); }
bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
bool IsError() const { return !IsOk(); }
static std::string CodeAsString(enum StatusCode c);
private:
enum StatusCode status_code_;
std::string status_msg_;
int line_of_code_;
std::string file_name_;
std::string err_description_;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_STATUS_H

View File

@ -16,15 +16,20 @@
#ifndef MINDSPORE_INCLUDE_API_TYPES_H
#define MINDSPORE_INCLUDE_API_TYPES_H
#include <cstddef>
#include <string>
#include <vector>
#include <memory>
#include "include/api/data_type.h"
#ifdef _WIN32
#define MS_API __declspec(dllexport)
#else
#define MS_API __attribute__((visibility("default")))
#endif
namespace mindspore {
namespace api {
enum ModelType {
enum ModelType : uint32_t {
kMindIR = 0,
kAIR = 1,
kOM = 2,
@ -33,52 +38,38 @@ enum ModelType {
kUnknownType = 0xFFFFFFFF
};
enum DataType {
kMsUnknown = 0,
kMsBool = 1,
kMsInt8 = 2,
kMsInt16 = 3,
kMsInt32 = 4,
kMsInt64 = 5,
kMsUint8 = 6,
kMsUint16 = 7,
kMsUint32 = 8,
kMsUint64 = 9,
kMsFloat16 = 10,
kMsFloat32 = 11,
kMsFloat64 = 12,
// insert new data type here
kInvalidDataType = 0xFFFFFFFF
};
class MS_API Tensor {
class MS_API MSTensor {
public:
Tensor();
Tensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, size_t data_len);
~Tensor();
class Impl;
static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
MSTensor();
explicit MSTensor(const std::shared_ptr<Impl> &impl);
MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
~MSTensor();
const std::string &Name() const;
void SetName(const std::string &name);
api::DataType DataType() const;
void SetDataType(api::DataType type);
enum DataType DataType() const;
const std::vector<int64_t> &Shape() const;
void SetShape(const std::vector<int64_t> &shape);
int64_t ElementNum() const;
const void *Data() const;
std::shared_ptr<const void> Data() const;
void *MutableData();
size_t DataSize() const;
bool ResizeData(size_t data_len);
bool SetData(const void *data, size_t data_len);
bool IsDevice() const;
int64_t ElementNum() const;
static int GetTypeSize(api::DataType type);
Tensor Clone() const;
MSTensor Clone() const;
bool operator==(std::nullptr_t) const;
private:
class Impl;
friend class ModelImpl;
explicit MSTensor(std::nullptr_t);
std::shared_ptr<Impl> impl_;
};
@ -101,21 +92,5 @@ class MS_API Buffer {
class Impl;
std::shared_ptr<Impl> impl_;
};
extern MS_API const char *kDeviceTypeAscend310;
extern MS_API const char *kDeviceTypeAscend910;
extern MS_API const char *kDeviceTypeGpu;
constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
constexpr auto kModelOptionInputFormat = "mindspore.option.input_format"; // nchw or nhwc
// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
constexpr auto kModelOptionOutputType = "mindspore.option.output_type"; // "FP32", "UINT8" or "FP16", default as "FP32"
constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
// "high_precision" or "high_performance", default as "high_performance"
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_TYPES_H

View File

@ -23,7 +23,7 @@ if(ENABLE_D)
endif()
if(ENABLE_GPU)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/gpu/*.cc")
endif()
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
@ -45,8 +45,13 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-Wl,-force_load mindspore -Wl,-noall_load proto_input mindspore_gvar mindspore::protobuf)
else()
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
if(ENABLE_D OR ENABLE_ACL)
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
else()
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
mindspore proto_input mindspore_gvar mindspore::protobuf)
endif()
endif()
if(ENABLE_CPU)

View File

@ -18,7 +18,7 @@
#include "cxx_api/factory.h"
#include "cxx_api/graph/graph_impl.h"
namespace mindspore::api {
namespace mindspore {
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {}
@ -40,23 +40,23 @@ ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
return *this;
}
ParameterCell::ParameterCell(const Tensor &tensor) : tensor_(tensor.Clone()) {}
ParameterCell::ParameterCell(const MSTensor &tensor) : tensor_(tensor.Clone()) {}
ParameterCell &ParameterCell::operator=(const Tensor &tensor) {
ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
tensor_ = tensor.Clone();
return *this;
}
ParameterCell::ParameterCell(Tensor &&tensor) : tensor_(tensor) {}
ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}
ParameterCell &ParameterCell::operator=(Tensor &&tensor) {
ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
tensor_ = tensor;
return *this;
}
GraphCell::GraphCell(const Graph &graph)
: graph_(std::make_shared<Graph>(graph)),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
@ -64,7 +64,7 @@ GraphCell::GraphCell(const Graph &graph)
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
: graph_(graph),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
@ -72,13 +72,13 @@ GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
GraphCell::GraphCell(Graph &&graph)
: graph_(std::make_shared<Graph>(graph)),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
}
Status GraphCell::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->Run(inputs, outputs);
}
@ -88,25 +88,24 @@ Status GraphCell::Load() {
return executor_->Load();
}
Status GraphCell::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> GraphCell::GetInputs() {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes);
return executor_->GetInputs();
}
Status GraphCell::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> GraphCell::GetOutputs() {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
return executor_->GetOutputs();
}
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
InputAndOutput::InputAndOutput(const Tensor &tensor)
InputAndOutput::InputAndOutput(const MSTensor &tensor)
: cell_(std::make_shared<ParameterCell>(tensor.Clone())), prev_(), index_(-1) {}
InputAndOutput::InputAndOutput(Tensor &&tensor) : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
InputAndOutput::InputAndOutput(MSTensor &&tensor)
: cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
int32_t index)
: cell_(cell), prev_(prev), index_(index) {}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -16,49 +16,119 @@
#include "include/api/context.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
class Context::ContextImpl {
public:
ContextImpl() : device_target_("NotSet"), device_id_(0) {}
~ContextImpl() = default;
const std::string &GetDeviceTarget() const { return device_target_; }
void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; }
uint32_t GetDeviceID() const { return device_id_; }
void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
constexpr auto kModelOptionInputFormat = "mindspore.option.input_format"; // nchw or nhwc
constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
constexpr auto kModelOptionOutputType = "mindspore.option.output_type"; // "FP32", "UINT8" or "FP16", default as "FP32"
constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
private:
std::string device_target_;
uint32_t device_id_;
};
namespace mindspore {
template <class T>
static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
auto iter = context->params.find(key);
if (iter == context->params.end()) {
return T();
}
const std::any &value = iter->second;
if (value.type() != typeid(T)) {
return T();
}
Context &Context::Instance() {
static Context context;
return context;
return std::any_cast<T>(value);
}
const std::string &Context::GetDeviceTarget() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetDeviceTarget();
std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
static std::shared_ptr<Context> g_context = std::make_shared<Context>();
return g_context;
}
Context &Context::SetDeviceTarget(const std::string &device_target) {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetDeviceTarget(device_target);
return *this;
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceTarget] = device_target;
}
uint32_t Context::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetDeviceID();
std::string GlobalContext::GetGlobalDeviceTarget() {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
}
Context &Context::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetDeviceID(device_id);
return *this;
void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceID] = device_id;
}
Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); }
uint32_t GlobalContext::GetGlobalDeviceID() {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
}
Context::~Context() {}
} // namespace mindspore::api
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInsertOpCfgPath] = cfg_path;
}
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
}
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputFormat] = format;
}
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputFormat);
}
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputShape] = shape;
}
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputShape);
}
void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOutputType] = output_type;
}
enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<enum DataType>(context, kModelOptionOutputType);
}
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionPrecisionMode] = precision_mode;
}
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionPrecisionMode);
}
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
}
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
}
} // namespace mindspore

View File

@ -23,7 +23,7 @@
#include <utility>
#include "utils/utils.h"
namespace mindspore::api {
namespace mindspore {
template <class T>
class Factory {
using U = std::function<std::shared_ptr<T>()>;
@ -79,5 +79,5 @@ class Registrar {
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H

View File

@ -17,8 +17,8 @@
#include "utils/log_adapter.h"
#include "acl/acl.h"
namespace mindspore::api {
std::weak_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
namespace mindspore {
std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
std::mutex AclEnvGuard::global_acl_env_mutex_;
AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {
@ -42,7 +42,7 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
std::shared_ptr<AclEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
acl_env = global_acl_env_.lock();
acl_env = global_acl_env_;
if (acl_env != nullptr) {
MS_LOG(INFO) << "Acl has been initialized, skip.";
} else {
@ -57,4 +57,4 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
}
return acl_env;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -20,7 +20,7 @@
#include <mutex>
#include "acl/acl_base.h"
namespace mindspore::api {
namespace mindspore {
class __attribute__((visibility("default"))) AclEnvGuard {
public:
explicit AclEnvGuard(std::string_view cfg_file);
@ -29,10 +29,10 @@ class __attribute__((visibility("default"))) AclEnvGuard {
static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file);
private:
static std::weak_ptr<AclEnvGuard> global_acl_env_;
static std::shared_ptr<AclEnvGuard> global_acl_env_;
static std::mutex global_acl_env_mutex_;
aclError errno_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_ENV_GUARD_H

View File

@ -16,53 +16,50 @@
#include "cxx_api/graph/acl/acl_graph_impl.h"
#include "include/api/context.h"
#include "cxx_api/model/acl/model_converter.h"
#include "cxx_api/python_utils.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);
AclGraphImpl::AclGraphImpl()
: init_flag_(false),
load_flag_(false),
device_type_("AscendCL"),
device_id_(Context::Instance().GetDeviceID()),
device_id_(GlobalContext::GetGlobalDeviceID()),
context_(nullptr),
acl_env_(nullptr) {}
AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); }
Status AclGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED;
return ret;
}
return model_process_.PredictFromHost(inputs, outputs);
}
Status AclGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> AclGraphImpl::GetInputs() {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED;
return {};
}
return model_process_.GetInputsInfo(names, shapes, data_types, mem_sizes);
return model_process_.GetInputs();
}
Status AclGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> AclGraphImpl::GetOutputs() {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED;
return {};
}
return model_process_.GetOutputsInfo(names, shapes, data_types, mem_sizes);
return model_process_.GetOutputs();
}
Status AclGraphImpl::LoadAclModel(Buffer om_data) {
@ -72,44 +69,44 @@ Status AclGraphImpl::LoadAclModel(Buffer om_data) {
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed.";
return FAILED;
return kMCDeviceError;
}
// acl init model resource
model_process_.set_model_id(acl_model_id);
Status ret = model_process_.PreInitModelResource();
if (ret != SUCCESS) {
if (ret != kSuccess) {
(void)aclmdlUnload(acl_model_id);
MS_LOG(ERROR) << "Pre init model resource failed.";
return FAILED;
return ret;
}
MS_LOG(INFO) << "Load acl model success.";
return SUCCESS;
return kSuccess;
}
Status AclGraphImpl::InitEnv() {
if (init_flag_) {
return SUCCESS;
return kSuccess;
}
acl_env_ = AclEnvGuard::GetAclEnv("");
if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed.";
return FAILED;
return kMCDeviceError;
}
aclError ret = aclrtSetDevice(device_id_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed";
return FAILED;
return kMCDeviceError;
}
MS_LOG(INFO) << "Open device " << device_id_ << " success";
ret = aclrtCreateContext(&context_, device_id_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl create context failed";
return FAILED;
return kMCDeviceError;
}
MS_LOG(INFO) << "Create context success";
@ -117,7 +114,7 @@ Status AclGraphImpl::InitEnv() {
ret = aclrtGetRunMode(&run_mode);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl get run mode failed";
return FAILED;
return kMCDeviceError;
}
bool is_device = (run_mode == ACL_DEVICE);
model_process_.SetIsDevice(is_device);
@ -125,24 +122,24 @@ Status AclGraphImpl::InitEnv() {
MS_LOG(INFO) << "Init acl success, device id " << device_id_;
init_flag_ = true;
return SUCCESS;
return kSuccess;
}
Status AclGraphImpl::FinalizeEnv() {
if (!init_flag_) {
return SUCCESS;
return kSuccess;
}
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED;
return kMCDeviceError;
}
Status ret = model_process_.UnLoad();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Unload model inner failed.";
return FAILED;
return ret;
}
if (context_ != nullptr) {
@ -161,16 +158,16 @@ Status AclGraphImpl::FinalizeEnv() {
MS_LOG(INFO) << "End to reset device " << device_id_;
init_flag_ = false;
return SUCCESS;
return kSuccess;
}
Status AclGraphImpl::Load() {
// check graph type
if (graph_->ModelType() != ModelType::kOM) {
Status ret = ConvertToOM();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load Failed.";
return FAILED;
return ret;
}
}
@ -180,15 +177,15 @@ Status AclGraphImpl::Load() {
// init
Status ret = InitEnv();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";
return FAILED;
return ret;
}
// load model
if (!load_flag_) {
ret = LoadAclModel(om_data);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load acl model failed.";
return ret;
}
@ -198,24 +195,24 @@ Status AclGraphImpl::Load() {
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED;
return kMCDeviceError;
}
return SUCCESS;
return kSuccess;
}
Status AclGraphImpl::ConvertToOM() {
MS_LOG(INFO) << "Start convert to om model.";
if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid graph_ is null.";
return FAILED;
return kMCFailed;
}
auto &graph_data = GraphImpl::MutableGraphData();
MS_EXCEPTION_IF_NULL(graph_data);
if (graph_->ModelType() == ModelType::kOM) {
MS_LOG(INFO) << "This model has been built, skip.";
return SUCCESS;
return kSuccess;
} else if (graph_->ModelType() == ModelType::kMindIR) {
auto func_graph = graph_data->GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -223,13 +220,13 @@ Status AclGraphImpl::ConvertToOM() {
Buffer om_data = model_converter.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Convert MindIR to OM failed.";
return FAILED;
return kMCFailed;
}
graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM);
MS_LOG(INFO) << "Convert MindIR to OM success.";
return SUCCESS;
return kSuccess;
}
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
return FAILED;
return kMCFailed;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -27,18 +27,16 @@
#include "cxx_api/graph/graph_impl.h"
#include "cxx_api/factory.h"
namespace mindspore::api {
namespace mindspore {
class AclGraphImpl : public GraphCell::GraphImpl {
public:
AclGraphImpl();
~AclGraphImpl() override;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
private:
Status ConvertToOM();
@ -56,5 +54,5 @@ class AclGraphImpl : public GraphCell::GraphImpl {
ModelProcess model_process_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H

View File

@ -20,17 +20,19 @@
#include <map>
#include "utils/utils.h"
namespace mindspore::api {
namespace mindspore {
static DataType TransToApiType(aclDataType data_type) {
static const std::map<aclDataType, api::DataType> data_type_map = {
{ACL_FLOAT16, api::kMsFloat16}, {ACL_FLOAT, api::kMsFloat32}, {ACL_DOUBLE, api::kMsFloat64},
{ACL_INT8, api::kMsInt8}, {ACL_INT16, api::kMsInt16}, {ACL_INT32, api::kMsInt32},
{ACL_INT64, api::kMsInt64}, {ACL_UINT8, api::kMsUint8}, {ACL_UINT16, api::kMsUint16},
{ACL_UINT32, api::kMsUint32}, {ACL_UINT64, api::kMsUint64}, {ACL_BOOL, api::kMsBool},
static const std::map<aclDataType, enum DataType> data_type_map = {
{ACL_FLOAT16, DataType::kNumberTypeFloat16}, {ACL_FLOAT, DataType::kNumberTypeFloat32},
{ACL_DOUBLE, DataType::kNumberTypeFloat64}, {ACL_INT8, DataType::kNumberTypeInt8},
{ACL_INT16, DataType::kNumberTypeInt16}, {ACL_INT32, DataType::kNumberTypeInt32},
{ACL_INT64, DataType::kNumberTypeInt64}, {ACL_UINT8, DataType::kNumberTypeUInt8},
{ACL_UINT16, DataType::kNumberTypeUInt16}, {ACL_UINT32, DataType::kNumberTypeUInt32},
{ACL_UINT64, DataType::kNumberTypeUInt64}, {ACL_BOOL, DataType::kNumberTypeBool},
};
auto it = data_type_map.find(data_type);
if (it == data_type_map.end()) {
return api::kInvalidDataType;
return DataType::kTypeUnknown;
} else {
return it->second;
}
@ -51,7 +53,7 @@ inline static void PushbackIfNotNull(U *vec, T &&item) {
}
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
std::vector<std::vector<int64_t>> *shapes, std::vector<DataType> *data_types,
std::vector<std::vector<int64_t>> *shapes, std::vector<enum DataType> *data_types,
std::vector<size_t> *mem_sizes) {
ClearIfNotNull(names);
ClearIfNotNull(shapes);
@ -66,41 +68,69 @@ static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_lis
}
}
static std::string ShapeToString(const std::vector<int64_t> &shape) {
std::string result = "[";
for (size_t i = 0; i < shape.size(); ++i) {
result += std::to_string(shape[i]);
if (i + 1 < shape.size()) {
result += ", ";
}
}
result += "]";
return result;
}
Status ModelProcess::ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list,
std::vector<MSTensor> *tensor_list) {
MS_EXCEPTION_IF_NULL(tensor_list);
std::vector<std::string> names;
std::vector<std::vector<int64_t>> shapes;
std::vector<enum DataType> data_types;
std::vector<size_t> mem_sizes;
ConstructTensorDesc(acl_tensor_list, &names, &shapes, &data_types, &mem_sizes);
tensor_list->clear();
if (names.size() != acl_tensor_list.size() || shapes.size() != acl_tensor_list.size() ||
data_types.size() != acl_tensor_list.size() || mem_sizes.size() != acl_tensor_list.size()) {
MS_LOG(ERROR) << "Inner error, size do not match: names size " << names.size() << " shapes size " << shapes.size()
<< " data types size " << data_types.size() << " mem sizes size " << mem_sizes.size()
<< " acl_tensor_list size " << acl_tensor_list.size();
return kMCFailed;
}
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
tensor_list->emplace_back(names[i], data_types[i], shapes[i], nullptr, mem_sizes[i]);
auto ret = aclrtMemcpy((*tensor_list)[i].MutableData(), (*tensor_list)[i].DataSize(),
acl_tensor_list[i].device_data, acl_tensor_list[i].buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Memcpy input " << i << " from " << (is_run_on_device_ ? "host" : "device")
<< " to host failed, memory size " << acl_tensor_list[i].buffer_size;
return kMCFailed;
}
}
return kSuccess;
}
Status ModelProcess::PreInitModelResource() {
model_desc_ = aclmdlCreateDesc();
aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Read model desc failed";
return FAILED;
return kMCDeviceError;
}
Status ret = InitInputsBuffer();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Create input buffer failed";
return FAILED;
return ret;
}
ret = InitOutputsBuffer();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Create output buffer failed";
return FAILED;
return ret;
}
return SUCCESS;
}
Status ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t *model_id) {
MS_EXCEPTION_IF_NULL(model_id);
aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), model_id);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name;
return FAILED;
}
MS_LOG(INFO) << "Load model success " << file_name;
model_id_ = *model_id;
if (PreInitModelResource() != SUCCESS) {
aclmdlUnload(model_id_);
MS_LOG(ERROR) << "Pre init model resource failed, file name is " << file_name;
return FAILED;
}
return SUCCESS;
return kSuccess;
}
Status ModelProcess::InitInputsBuffer() {
@ -113,8 +143,8 @@ Status ModelProcess::InitInputsBuffer() {
if (!is_run_on_device_) { // need to copy input/output to/from device
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device input buffer faild , input size " << buffer_size;
return FAILED;
MS_LOG(ERROR) << "Malloc device input buffer failed , input size " << buffer_size;
return kMCDeviceError;
}
}
@ -125,7 +155,7 @@ Status ModelProcess::InitInputsBuffer() {
if (!is_run_on_device_) {
aclrtFree(data_mem_buffer);
}
return FAILED;
return kMCDeviceError;
}
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
@ -137,7 +167,7 @@ Status ModelProcess::InitInputsBuffer() {
input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, input_name});
}
MS_LOG(INFO) << "Create model inputs success";
return SUCCESS;
return kSuccess;
}
Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
@ -154,14 +184,14 @@ Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size
if (!is_run_on_device_) {
ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size;
return FAILED;
MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
return kMCDeviceError;
}
} else {
ret = aclrtMallocHost(data_mem_buffer, buffer_size);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size;
return FAILED;
MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
return kMCDeviceError;
}
}
@ -169,16 +199,16 @@ Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size
if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed";
free_data_buffer(*data_mem_buffer);
return FAILED;
return kMCDeviceError;
}
ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed";
free_data_buffer(*data_mem_buffer);
aclDestroyDataBuffer(data_buffer);
return FAILED;
return kMCDeviceError;
}
return SUCCESS;
return kSuccess;
}
Status ModelProcess::InitOutputsBuffer() {
@ -186,7 +216,7 @@ Status ModelProcess::InitOutputsBuffer() {
outputs_ = aclmdlCreateDataset();
if (outputs_ == nullptr) {
MS_LOG(ERROR) << "Create input dataset failed";
return FAILED;
return kMCDeviceError;
}
size_t output_size = aclmdlGetNumOutputs(model_desc_);
MS_LOG(INFO) << "output_size = " << output_size;
@ -194,9 +224,9 @@ Status ModelProcess::InitOutputsBuffer() {
auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
void *data_mem_buffer = nullptr;
if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != SUCCESS) {
if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != kSuccess) {
MS_LOG(ERROR) << "add output data buffer failed, buffer size " << buffer_size;
return FAILED;
return kMCDeviceError;
}
aclmdlIODims dims;
ret = aclmdlGetOutputDims(model_desc_, i, &dims);
@ -207,7 +237,7 @@ Status ModelProcess::InitOutputsBuffer() {
} else {
aclrtFreeHost(data_mem_buffer);
}
return FAILED;
return kMCDeviceError;
}
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
@ -219,7 +249,7 @@ Status ModelProcess::InitOutputsBuffer() {
output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, output_name});
}
MS_LOG(INFO) << "Create model output success";
return SUCCESS;
return kSuccess;
}
void ModelProcess::DestroyInputsDataset() {
@ -273,50 +303,60 @@ Status ModelProcess::UnLoad() {
auto ret = aclmdlUnload(model_id_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed";
return FAILED;
return kMCDeviceError;
}
if (model_desc_ != nullptr) {
ret = aclmdlDestroyDesc(model_desc_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed";
return FAILED;
return kMCDeviceError;
}
model_desc_ = nullptr;
}
DestroyInputsBuffer();
DestroyOutputsBuffer();
MS_LOG(INFO) << "End unload model " << model_id_;
return SUCCESS;
return kSuccess;
}
Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) {
Status ModelProcess::CheckAndInitInput(const std::vector<MSTensor> &inputs) {
aclError ret;
inputs_ = aclmdlCreateDataset();
// check inputs
if (inputs.size() != input_infos_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given count "
MS_LOG(ERROR) << "Inputs count not match, required count " << input_infos_.size() << ", given count "
<< inputs.size();
return INVALID_INPUTS;
return kMCInvalidInput;
}
for (size_t i = 0; i < input_infos_.size(); ++i) {
if (inputs[i].Shape() != input_infos_[i].dims) {
MS_LOG(INFO) << "Note: input " << i << " shape not match, required " << ShapeToString(input_infos_[i].dims)
<< ", given " << ShapeToString(inputs[i].Shape());
}
if (inputs[i].DataType() != TransToApiType(input_infos_[i].data_type)) {
MS_LOG(INFO) << "Note: input " << i << " data type not match, required "
<< TransToApiType(input_infos_[i].data_type) << ", given " << inputs[i].DataType();
}
if (inputs[i].DataSize() != input_infos_[i].buffer_size) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size
MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << input_infos_[i].buffer_size
<< ", given count " << inputs[i].DataSize();
return INVALID_INPUTS;
return kMCInvalidInput;
}
}
// copy inputs
for (size_t i = 0; i < input_infos_.size(); ++i) {
const auto &info = input_infos_[i];
const auto &input = inputs[i];
const void *data = input.Data();
auto input = inputs[i];
const void *data = input.MutableData();
void *input_buffer = nullptr;
if (!is_run_on_device_) {
ret = aclrtMemcpy(info.device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize();
return FAILED;
return kMCDeviceError;
}
input_buffer = info.device_data;
} else {
@ -325,23 +365,23 @@ Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) {
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed";
return FAILED;
return kMCDeviceError;
}
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed";
aclDestroyDataBuffer(data_buffer);
return FAILED;
return kMCDeviceError;
}
}
return SUCCESS;
return kSuccess;
}
Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status ModelProcess::PredictFromHost(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
aclError acl_ret;
Status ret = CheckAndInitInput(inputs);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "check or init input failed";
DestroyInputsDataset();
return ret; // forward status error
@ -361,50 +401,48 @@ Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vec
DestroyInputsDataset();
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCDeviceError;
}
ret = BuildOutputs(outputs);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Build outputs faield";
return FAILED;
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build outputs failed";
return ret;
}
MS_LOG(INFO) << "excute model success";
return SUCCESS;
MS_LOG(INFO) << "Execute model success";
return kSuccess;
}
Status ModelProcess::BuildOutputs(std::vector<Buffer> *outputs) {
Status ModelProcess::BuildOutputs(std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
aclError ret;
// copy outputs
outputs->clear();
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
for (size_t i = 0; i < output_infos_.size(); ++i) {
const auto &info = output_infos_[i];
outputs->emplace_back(Buffer());
auto output = outputs->rbegin();
if (!output->ResizeData(info.buffer_size)) {
MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size;
return FAILED;
}
ret = aclrtMemcpy(output->MutableData(), output->DataSize(), info.device_data, info.buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device")
<< " to host failed, memory size " << info.buffer_size;
return FAILED;
}
auto inner_outputs = GetOutputs();
if (inner_outputs.size() != output_infos_.size()) {
MS_LOG(ERROR) << "Invalid inner outputs size " << inner_outputs.size() << " do not match device output infos size "
<< output_infos_.size();
return kMCFailed;
}
return SUCCESS;
(*outputs) = inner_outputs;
return kSuccess;
}
Status ModelProcess::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
ConstructTensorDesc(input_infos_, names, shapes, data_types, mem_sizes);
return SUCCESS;
std::vector<MSTensor> ModelProcess::GetInputs() {
Status ret = ConstructTensors(input_infos_, &input_tensors_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "ConstructTensors failed.";
input_tensors_.clear();
}
return input_tensors_;
}
Status ModelProcess::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
ConstructTensorDesc(output_infos_, names, shapes, data_types, mem_sizes);
return SUCCESS;
std::vector<MSTensor> ModelProcess::GetOutputs() {
Status ret = ConstructTensors(output_infos_, &output_tensors_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "ConstructTensors failed.";
output_tensors_.clear();
}
return output_tensors_;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -25,7 +25,7 @@
#include "include/api/status.h"
#include "include/api/types.h"
namespace mindspore::api {
namespace mindspore {
struct AclTensorInfo {
void *device_data;
size_t buffer_size;
@ -45,14 +45,12 @@ class ModelProcess {
input_infos_(),
output_infos_() {}
~ModelProcess() {}
Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id);
Status UnLoad();
Status PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status PredictFromHost(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status PreInitModelResource();
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
// override this method to avoid request/reply data copy
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
@ -62,8 +60,9 @@ class ModelProcess {
private:
Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
Status CheckAndInitInput(const std::vector<Buffer> &inputs);
Status BuildOutputs(std::vector<Buffer> *outputs);
Status CheckAndInitInput(const std::vector<MSTensor> &inputs);
Status ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<MSTensor> *tensor_list);
Status BuildOutputs(std::vector<MSTensor> *outputs);
Status InitInputsBuffer();
Status InitOutputsBuffer();
@ -80,7 +79,9 @@ class ModelProcess {
aclmdlDataset *outputs_;
std::vector<AclTensorInfo> input_infos_;
std::vector<AclTensorInfo> output_infos_;
std::vector<MSTensor> input_tensors_;
std::vector<MSTensor> output_tensors_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H

View File

@ -25,91 +25,51 @@
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
namespace mindspore::api {
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
AscendGraphImpl::AscendGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_type_("Ascend"),
device_id_(Context::Instance().GetDeviceID()),
device_id_(GlobalContext::GetGlobalDeviceID()),
context_(nullptr),
inputs_(),
outputs_(),
inputs_info_(),
outputs_info_(),
input_names_(),
output_names_(),
init_flag_(false),
load_flag_(false) {}
AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); }
AscendGraphImpl::~AscendGraphImpl() {}
Status AscendGraphImpl::InitEnv() {
if (init_flag_) {
return SUCCESS;
}
RegAllOp();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;
MS_LOG(INFO) << "Start to init env.";
env_guard_ = MsEnvGuard::GetEnv(device_id_);
if (env_guard_ == nullptr) {
MS_LOG(ERROR) << "Env init failed.";
return kMCDeviceError;
}
session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
<< " is available.";
return FAILED;
return kMCFailed;
}
session_impl_->Init(device_id_);
init_flag_ = true;
return SUCCESS;
}
Status AscendGraphImpl::FinalizeEnv() {
if (!init_flag_) {
return SUCCESS;
}
MS_LOG_INFO << "Start finalize env";
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
{
PythonEnvGuard guard;
if (!context::CloseTsd(ms_context)) {
MS_LOG(ERROR) << "CloseTsd failed!";
return FAILED;
}
}
init_flag_ = false;
MS_LOG(INFO) << "End finalize env";
return SUCCESS;
MS_LOG(INFO) << "InitEnv success.";
return kSuccess;
}
Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return SUCCESS;
return kSuccess;
} catch (std::exception &e) {
MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
return FAILED;
return kMCFailed;
}
}
@ -128,104 +88,104 @@ Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &i
MS_ASSERT(session_impl_ != nullptr);
std::string error_msg;
if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
return Status(INVALID_INPUTS, error_msg);
return Status(kMCInvalidInput, error_msg);
}
return SUCCESS;
return kSuccess;
}
Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
Status AscendGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vector<MSTensor> *reply) {
MS_EXCEPTION_IF_NULL(reply);
if (context_ == nullptr) {
MS_LOG(ERROR) << "rtCtx is nullptr";
return FAILED;
return kMCDeviceError;
}
rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Set Ascend rtCtx failed";
return FAILED;
return kMCDeviceError;
}
vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i];
auto input = inputs_[i];
auto item = request[i];
auto input = inputs_info_[i];
if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size();
return FAILED;
return kMCInvalidInput;
}
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Tensor copy failed";
return FAILED;
auto ret = memcpy_s(input->data_c(), input->Size(), item.MutableData(), item.DataSize());
if (ret != kSuccess) {
MS_LOG(ERROR) << "MSTensor copy failed";
return kMCFailed;
}
inputs.push_back(input);
}
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
last_inputs_ = inputs;
std::vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCFailed;
}
last_outputs_ = outputs;
reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
return SUCCESS;
*reply = GetOutputs();
return kSuccess;
}
Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> AscendGraphImpl::GetInputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < inputs_.size(); i++) {
auto &tensor = inputs_[i];
GraphUtils::PushbackIfNotNull(names, input_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(inputs_info_.size());
for (size_t i = 0; i < inputs_info_.size(); ++i) {
auto &tensor = inputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_inputs_.size()) {
data = last_inputs_[i]->data_c();
data_size = last_inputs_[i]->Size();
}
result[i] =
MSTensor(input_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}
return SUCCESS;
return result;
}
Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < outputs_.size(); i++) {
auto &tensor = outputs_[i];
GraphUtils::PushbackIfNotNull(names, output_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(outputs_info_.size());
for (size_t i = 0; i < outputs_info_.size(); ++i) {
auto &tensor = outputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_outputs_.size()) {
data = last_outputs_[i]->data_c();
data_size = last_outputs_[i]->Size();
}
result[i] =
MSTensor(output_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}
return SUCCESS;
return result;
}
Status AscendGraphImpl::Load() {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
return INVALID_INPUTS;
return kMCInvalidInput;
}
const auto &graph_data = GraphImpl::MutableGraphData();
@ -234,34 +194,34 @@ Status AscendGraphImpl::Load() {
// init
Status ret = InitEnv();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";
return FAILED;
return ret;
}
// load model
if (!load_flag_) {
ret = CompileGraph(func_graph);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Compile graph model failed";
return FAILED;
return ret;
}
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed";
return FAILED;
return kMCInvalidInput;
}
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed";
return FAILED;
return kMCInvalidInput;
}
// save d context
rtError_t rt_ret = rtCtxGetCurrent(&context_);
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null";
return FAILED;
return kMCDeviceError;
}
MS_LOG(INFO) << "Load model success";
@ -271,44 +231,112 @@ Status AscendGraphImpl::Load() {
rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED;
return kMCDeviceError;
}
return SUCCESS;
return kSuccess;
}
Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
}
}
if (inputs.size() != inputs_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
return INVALID_INPUTS;
if (inputs.size() != inputs_info_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count "
<< inputs.size();
return kMCInvalidInput;
}
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs[i].DataSize() != inputs_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
<< inputs[i].DataSize();
return INVALID_INPUTS;
for (size_t i = 0; i < inputs_info_.size(); ++i) {
if (inputs[i].DataSize() != inputs_info_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size()
<< ", given count " << inputs[i].DataSize();
return kMCInvalidInput;
}
}
if (ExecuteModel(inputs, outputs) != SUCCESS) {
Status ret = ExecuteModel(inputs, outputs);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return ret;
}
if (outputs_.size() != outputs->size()) {
if (outputs_info_.size() != outputs->size()) {
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
<< outputs_.size();
return FAILED;
<< outputs_info_.size();
return kMCFailed;
}
return SUCCESS;
return kSuccess;
}
} // namespace mindspore::api
AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
MS_LOG(INFO) << "Start to init env.";
device_id_ = device_id;
RegAllOp();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
auto ret = rtSetDevice(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
}
MS_LOG(INFO) << "InitEnv success.";
errno_ = kSuccess;
}
AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
MS_LOG(INFO) << "Start finalize env";
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return;
}
auto ret = rtDeviceReset(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
}
errno_ = kSuccess;
MS_LOG(INFO) << "End finalize env";
}
std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) {
std::shared_ptr<MsEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_ms_env_mutex_);
acl_env = global_ms_env_.lock();
if (acl_env != nullptr) {
MS_LOG(INFO) << "Env has been initialized, skip.";
} else {
acl_env = std::make_shared<MsEnvGuard>(device_id);
if (acl_env->GetErrno() != kSuccess) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return nullptr;
}
global_ms_env_ = acl_env;
MS_LOG(INFO) << "Env init success";
}
return acl_env;
}
std::weak_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
} // namespace mindspore

View File

@ -28,40 +28,56 @@
#include "ir/anf.h"
#include "cxx_api/model/model_impl.h"
#include "runtime/context.h"
#include "cxx_api/graph/graph_utils.h"
namespace mindspore::api {
namespace mindspore {
class AscendGraphImpl : public GraphCell::GraphImpl {
public:
AscendGraphImpl();
~AscendGraphImpl() override;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
private:
class MsEnvGuard;
Status InitEnv();
Status FinalizeEnv();
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status ExecuteModel(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
std::shared_ptr<session::SessionBasic> session_impl_;
uint32_t graph_id_;
std::string device_type_;
uint32_t device_id_;
rtContext_t context_;
std::vector<tensor::TensorPtr> inputs_;
std::vector<tensor::TensorPtr> outputs_;
std::vector<tensor::TensorPtr> inputs_info_;
std::vector<tensor::TensorPtr> outputs_info_;
std::vector<tensor::TensorPtr> last_inputs_;
std::vector<tensor::TensorPtr> last_outputs_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
bool init_flag_;
bool load_flag_;
std::shared_ptr<MsEnvGuard> env_guard_;
};
} // namespace mindspore::api
class AscendGraphImpl::MsEnvGuard {
public:
explicit MsEnvGuard(uint32_t device_id);
~MsEnvGuard();
Status GetErrno() const { return errno_; }
static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id);
private:
static std::weak_ptr<MsEnvGuard> global_ms_env_;
static std::mutex global_ms_env_mutex_;
Status errno_;
uint32_t device_id_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H

View File

@ -23,15 +23,15 @@
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
namespace mindspore::api {
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_id_(Context::Instance().GetDeviceID()),
inputs_(),
outputs_(),
device_id_(GlobalContext::GetGlobalDeviceID()),
inputs_info_(),
outputs_info_(),
input_names_(),
output_names_(),
init_flag_(false),
@ -40,13 +40,13 @@ GPUGraphImpl::GPUGraphImpl()
Status GPUGraphImpl::InitEnv() {
if (init_flag_) {
MS_LOG(WARNING) << "Initialized again, return success.";
return SUCCESS;
return kSuccess;
}
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
return kMCFailed;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
@ -57,18 +57,18 @@ Status GPUGraphImpl::InitEnv() {
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kGpuInferenceDevice
<< " is available.";
return FAILED;
return kMCFailed;
}
session_impl_->Init(device_id_);
init_flag_ = true;
return SUCCESS;
return kSuccess;
}
Status GPUGraphImpl::FinalizeEnv() {
if (!init_flag_) {
MS_LOG(WARNING) << "Never initialize before, return success";
return SUCCESS;
return kSuccess;
}
MS_LOG_INFO << "Start finalize env";
@ -77,14 +77,14 @@ Status GPUGraphImpl::FinalizeEnv() {
init_flag_ = false;
MS_LOG(INFO) << "End finalize env";
return SUCCESS;
return kSuccess;
}
Status GPUGraphImpl::Load() {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
return INVALID_INPUTS;
return kMCInvalidInput;
}
const auto &graph_data = GraphImpl::MutableGraphData();
@ -93,38 +93,38 @@ Status GPUGraphImpl::Load() {
// init
Status ret = InitEnv();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";
return FAILED;
return kMCDeviceError;
}
ret = CompileGraph(func_graph);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Compile graph model failed";
return FAILED;
return kMCFailed;
}
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed";
return FAILED;
return kMCInvalidInput;
}
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed";
return FAILED;
return kMCInvalidInput;
}
load_flag_ = true;
return SUCCESS;
return kSuccess;
}
Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return SUCCESS;
return kSuccess;
} catch (std::exception &e) {
MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
return FAILED;
return kMCFailed;
}
}
@ -139,118 +139,118 @@ std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::
}
}
Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
Status GPUGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vector<MSTensor> *reply) {
MS_EXCEPTION_IF_NULL(reply);
vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i];
auto input = inputs_[i];
auto input = inputs_info_[i];
if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size();
return FAILED;
return kMCInvalidInput;
}
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
if (ret != SUCCESS) {
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data().get(), item.DataSize());
if (ret != kSuccess) {
MS_LOG(ERROR) << "Tensor copy failed";
return FAILED;
return kMCFailed;
}
inputs.push_back(input);
}
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
last_inputs_ = inputs;
std::vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCFailed;
}
last_outputs_ = outputs;
reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
return SUCCESS;
*reply = GetOutputs();
return kSuccess;
}
Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
}
}
if (inputs.size() != inputs_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
return INVALID_INPUTS;
if (inputs.size() != inputs_info_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count "
<< inputs.size();
return kMCInvalidInput;
}
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs[i].DataSize() != inputs_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
<< inputs[i].DataSize();
return INVALID_INPUTS;
for (size_t i = 0; i < inputs_info_.size(); ++i) {
if (inputs[i].DataSize() != inputs_info_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size()
<< ", given count " << inputs[i].DataSize();
return kMCInvalidInput;
}
}
if (ExecuteModel(inputs, outputs) != SUCCESS) {
if (ExecuteModel(inputs, outputs) != kSuccess) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCFailed;
}
if (outputs_.size() != outputs->size()) {
if (outputs_info_.size() != outputs->size()) {
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
<< outputs_.size();
return FAILED;
<< outputs_info_.size();
return kMCFailed;
}
return SUCCESS;
return kSuccess;
}
Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> GPUGraphImpl::GetInputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < inputs_.size(); i++) {
auto &tensor = inputs_[i];
GraphUtils::PushbackIfNotNull(names, input_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(inputs_info_.size());
for (size_t i = 0; i < inputs_info_.size(); ++i) {
auto &tensor = inputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_inputs_.size()) {
data = last_inputs_[i]->data_c();
data_size = last_inputs_[i]->Size();
}
result[i] =
MSTensor(input_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}
return SUCCESS;
return result;
}
Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}
GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < outputs_.size(); i++) {
auto &tensor = outputs_[i];
GraphUtils::PushbackIfNotNull(names, output_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(outputs_info_.size());
for (size_t i = 0; i < outputs_info_.size(); ++i) {
auto &tensor = outputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_outputs_.size()) {
data = last_outputs_[i]->data_c();
data_size = last_outputs_[i]->Size();
}
result[i] =
MSTensor(output_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}
return SUCCESS;
return result;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -25,20 +25,17 @@
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/graph/graph_utils.h"
namespace mindspore::api {
namespace mindspore {
class GPUGraphImpl : public GraphCell::GraphImpl {
public:
GPUGraphImpl();
~GPUGraphImpl() override = default;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
private:
Status InitEnv();
@ -46,14 +43,16 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status ExecuteModel(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
std::shared_ptr<session::SessionBasic> session_impl_;
uint32_t graph_id_;
std::string device_type_;
uint32_t device_id_;
std::vector<tensor::TensorPtr> inputs_;
std::vector<tensor::TensorPtr> outputs_;
std::vector<tensor::TensorPtr> inputs_info_;
std::vector<tensor::TensorPtr> outputs_info_;
std::vector<tensor::TensorPtr> last_inputs_;
std::vector<tensor::TensorPtr> last_outputs_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
bool init_flag_;
@ -63,5 +62,5 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
uint32_t batch_size_;
uint32_t workspace_size_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H

View File

@ -17,15 +17,19 @@
#include "cxx_api/graph/graph_data.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
namespace mindspore {
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
Graph::~Graph() {}
Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
ModelType Graph::ModelType() const {
MS_EXCEPTION_IF_NULL(graph_data_);
return graph_data_->ModelType();
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -19,7 +19,7 @@
#include "framework/common/helper/model_helper.h"
#endif
namespace mindspore::api {
namespace mindspore {
Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
if (model_type != ModelType::kMindIR) {
@ -72,4 +72,4 @@ Buffer Graph::GraphData::GetOMData() const {
return om_data_;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -24,7 +24,7 @@
#include "include/api/types.h"
#include "ir/func_graph.h"
namespace mindspore::api {
namespace mindspore {
class Graph::GraphData {
public:
GraphData();
@ -46,5 +46,5 @@ class Graph::GraphData {
Buffer om_data_;
enum ModelType model_type_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H

View File

@ -26,7 +26,7 @@
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"
namespace mindspore::api {
namespace mindspore {
class GraphCell::GraphImpl {
public:
GraphImpl() = default;
@ -35,17 +35,14 @@ class GraphCell::GraphImpl {
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
virtual Status Load() = 0;
virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;
protected:
std::shared_ptr<Graph> graph_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H

View File

@ -1,63 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#include <map>
#include <vector>
#include "include/api/types.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
class GraphUtils {
public:
static DataType TransTypeId2InferDataType(TypeId type_id) {
const std::map<TypeId, api::DataType> id2type_map{
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
};
auto it = id2type_map.find(type_id);
if (it != id2type_map.end()) {
return it->second;
}
MS_LOG(WARNING) << "Unsupported data id " << type_id;
return api::kMsUnknown;
}
template <class T>
inline static void ClearIfNotNull(T *vec) {
if (vec != nullptr) {
vec->clear();
}
}
template <class T, class U>
inline static void PushbackIfNotNull(U *vec, T &&item) {
if (vec != nullptr) {
vec->emplace_back(item);
}
}
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H

View File

@ -16,47 +16,53 @@
#include "cxx_api/model/acl/acl_model.h"
#include <memory>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/python_utils.h"
namespace mindspore::api {
namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
Status AclModel::Build(const std::map<std::string, std::string> &options_map) {
Status AclModel::Build() {
MS_LOG(INFO) << "Start build model.";
MS_EXCEPTION_IF_NULL(graph_);
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(options_map);
std::string options_str = GenerateOptionsStr(options_map);
MS_EXCEPTION_IF_NULL(options);
if (graph_cell_ != nullptr && options_str == options_str_) {
if (graph_cell_ != nullptr) {
MS_LOG(INFO) << "This model has been built, skip.";
return SUCCESS;
return kSuccess;
}
if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) {
MS_LOG(INFO) << "Note: Load om model and all build options will be ignored.";
graph_cell_ = std::make_shared<GraphCell>(graph_);
MS_EXCEPTION_IF_NULL(graph_cell_);
if (!options_map.empty()) {
MS_LOG(WARNING) << "All build options will be ignored.";
return kSuccess;
}
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_);
MS_EXCEPTION_IF_NULL(options);
std::string options_key = options->GenAclOptionsKey();
std::shared_ptr<Graph> graph;
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
MS_LOG(INFO) << "This options has been built, read cache.";
graph = iter->second;
} else {
auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
model_converter_.set_options(options.get());
auto om_data = model_converter_.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed.";
return kMCFailed;
}
return SUCCESS;
graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
dynamic_size_graph_map_[options_key] = graph;
}
auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
model_converter_.set_options(options.get());
auto om_data = model_converter_.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed.";
return FAILED;
}
auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
auto ret = ModelImpl::Load(graph_cell);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";
return ret;
}
@ -64,64 +70,97 @@ Status AclModel::Build(const std::map<std::string, std::string> &options_map) {
// save result
graph_cell_ = graph_cell;
options_ = std::move(options);
options_str_ = options_str;
MS_LOG(INFO) << "Build model success.";
return SUCCESS;
return kSuccess;
}
Status AclModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
MS_LOG(INFO) << "Start to resize model.";
MS_EXCEPTION_IF_NULL(graph_);
if (graph_->ModelType() == ModelType::kOM) {
MS_LOG(ERROR) << "OM model is not supported to resize model.";
return kMCFailed;
}
auto origin_inputs = GetInputs();
if (inputs.size() != origin_inputs.size()) {
MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size();
return kMCInvalidInput;
}
if (inputs.size() != dims.size()) {
MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size();
return kMCInvalidInput;
}
if (model_context_ == nullptr) {
model_context_ = std::make_shared<ModelContext>();
}
std::string input_shape_option;
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].Name() != origin_inputs[i].Name()) {
MS_LOG(ERROR) << "Invalid inputs " << i << " name " << inputs[i].Name() << " not match model input name "
<< origin_inputs[i].Name();
return kMCInvalidInput;
}
input_shape_option += inputs[i].Name() + ":";
for (size_t j = 0; j < dims[i].size(); ++j) {
input_shape_option += std::to_string(dims[i][j]);
if (j + 1 < dims[i].size()) {
input_shape_option += ",";
}
}
if (i + 1 < inputs.size()) {
input_shape_option += ";";
}
}
MS_LOG(INFO) << "Set input size option is " << input_shape_option;
ModelContext::SetInputShape(model_context_, input_shape_option);
auto graph_cell_bak = std::move(graph_cell_);
auto ret = Build();
if (ret != kSuccess) {
MS_LOG(INFO) << "Resize build failed.";
graph_cell_ = std::move(graph_cell_bak);
return ret;
}
MS_LOG(INFO) << "Resize success.";
return kSuccess;
}
Status AclModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
}
Status AclModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status AclModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid data, graph_ is null.";
return FAILED;
return kMCFailed;
}
if (graph_cell_ == nullptr) {
MS_LOG(WARNING) << "Model has not been built, it will be built with default options";
Status ret = Build({});
if (ret != SUCCESS) {
Status ret = Build();
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed.";
return FAILED;
return ret;
}
}
MS_EXCEPTION_IF_NULL(graph_cell_);
Status ret = graph_cell_->Run(inputs, outputs);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Run graph failed.";
return FAILED;
return ret;
}
return SUCCESS;
return kSuccess;
}
Status AclModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> AclModel::GetInputs() {
MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes);
return graph_cell_->GetInputs();
}
Status AclModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> AclModel::GetOutputs() {
MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
return graph_cell_->GetOutputs();
}
std::string AclModel::GenerateOptionsStr(const std::map<std::string, std::string> &options) {
std::string ret;
for (auto &[key, value] : options) {
ret += key + "^" + value + "^^";
}
return ret;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -31,30 +31,25 @@
#include "ir/tensor.h"
#include "ir/anf.h"
namespace mindspore::api {
namespace mindspore {
class AclModel : public ModelImpl {
public:
AclModel() : model_converter_(), options_(nullptr), options_str_() {}
AclModel() : model_converter_(), options_(nullptr) {}
~AclModel() = default;
Status Build(const std::map<std::string, std::string> &options_map) override;
Status Build() override;
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
private:
static std::string GenerateOptionsStr(const std::map<std::string, std::string> &options);
std::shared_ptr<GraphCell> graph_cell_;
ModelConverter model_converter_;
std::unique_ptr<AclModelOptions> options_;
std::string options_str_;
std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H

View File

@ -18,23 +18,31 @@
#include "utils/log_adapter.h"
#include "external/ge/ge_api_types.h"
namespace mindspore::api {
static std::string ParseOption(const std::map<std::string, std::string> &options, const std::string &key) {
auto iter = options.find(key);
if (iter != options.end()) {
return iter->second;
}
return "";
}
namespace mindspore {
static const std::map<enum DataType, std::string> kSupportedDtypeOptionMap = {{DataType::kNumberTypeFloat16, "FP16"},
{DataType::kNumberTypeFloat32, "FP32"},
{DataType::kNumberTypeUInt8, "UINT8"}};
AclModelOptions::AclModelOptions(const std::map<std::string, std::string> &options) {
// to acl
insert_op_cfg_path = ParseOption(options, kModelOptionInsertOpCfgPath);
input_format = ParseOption(options, kModelOptionInputFormat);
input_shape = ParseOption(options, kModelOptionInputShape);
output_type = ParseOption(options, kModelOptionOutputType);
precision_mode = ParseOption(options, kModelOptionPrecisionMode);
op_select_impl_mode = ParseOption(options, kModelOptionOpSelectImplMode);
AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
if (context == nullptr) {
return;
}
insert_op_cfg_path = ModelContext::GetInsertOpConfigPath(context);
input_format = ModelContext::GetInputFormat(context);
input_shape = ModelContext::GetInputShape(context);
auto out_type = ModelContext::GetOutputType(context);
auto iter = kSupportedDtypeOptionMap.find(out_type);
if (out_type == DataType::kTypeUnknown) {
// do nothing
} else if (iter == kSupportedDtypeOptionMap.end()) {
MS_LOG(WARNING) << "Unsupported output type " << out_type << ", use FP32 as default.";
} else {
output_type = iter->second;
}
precision_mode = ModelContext::GetPrecisionMode(context);
op_select_impl_mode = ModelContext::GetOpSelectImplMode(context);
}
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> AclModelOptions::GenAclOptions()
@ -69,4 +77,16 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
}
return {init_options, build_options};
}
} // namespace mindspore::api
std::string AclModelOptions::GenAclOptionsKey() const {
auto [init_options, build_options] = GenAclOptions();
std::string key_str;
for (auto &[key, value] : init_options) {
key_str += key + "^" + value + "^^";
}
for (auto &[key, value] : build_options) {
key_str += key + "^" + value + "^^";
}
return key_str;
}
} // namespace mindspore

View File

@ -20,12 +20,13 @@
#include <string>
#include <map>
#include <tuple>
#include <memory>
#include "include/api/types.h"
#include "include/api/status.h"
#include "include/api/context.h"
namespace mindspore::api {
namespace mindspore {
struct AclModelOptions {
std::string output_node; // todo: at convert.cc::BuildGraph(), no atc options
// build options
std::string insert_op_cfg_path;
std::string input_format;
@ -35,12 +36,13 @@ struct AclModelOptions {
std::string op_select_impl_mode;
std::string soc_version = "Ascend310";
explicit AclModelOptions(const std::map<std::string, std::string> &options);
explicit AclModelOptions(const std::shared_ptr<Context> &context);
~AclModelOptions() = default;
// return tuple<init_options, build_options>
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;
std::string GenAclOptionsKey() const;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H

View File

@ -22,9 +22,8 @@
#include "include/api/serialization.h"
#include "graph/model.h"
#include "cxx_api/model/model_converter_utils/multi_process.h"
#include "cxx_api/python_utils.h"
namespace mindspore::api {
namespace mindspore {
namespace {
transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
transform::TensorOrderMap res;
@ -86,25 +85,25 @@ transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &
para->set_name(name);
}
transform::DfGraphConvertor convertor(anf_graph);
transform::DfGraphConvertor converter(anf_graph);
std::string net_id = "0";
std::string init_graph = "init_subgraph." + net_id;
std::string checkpoint_name = "save." + net_id;
convertor.set_training(false);
(void)convertor.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph();
(void)convertor.GenerateCheckpointGraph();
if (convertor.ErrCode() != 0) {
converter.set_training(false);
(void)converter.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph();
(void)converter.GenerateCheckpointGraph();
if (converter.ErrCode() != 0) {
transform::DfGraphManager::GetInstance().ClearGraph();
MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode();
MS_LOG(ERROR) << "Convert df graph failed, err:" << converter.ErrCode();
return nullptr;
}
(void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), convertor.GetComputeGraph());
(void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph());
(void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph());
(void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), converter.GetComputeGraph());
(void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, converter.GetInitGraph());
(void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, converter.GetBroadcastGraph());
transform::Status ret =
transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph());
transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, converter.GetSaveCheckpointGraph());
if (ret == transform::Status::SUCCESS) {
transform::DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph);
}
@ -158,7 +157,7 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
auto df_graph = ConvertFuncGraphToAIR(func_graph);
if (df_graph == nullptr) {
MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
return FAILED;
return kMCFailed;
}
ge::Model model;
ge::Buffer model_data;
@ -166,14 +165,14 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
auto ge_ret = model.Save(model_data);
if (ge_ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Save ge model to buffer failed.";
return FAILED;
return kMCFailed;
}
// send original model to child
auto status = multi_process->SendMsg(model_data.data(), model_data.size());
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Send original model to child process failed";
return FAILED;
return status;
}
// receive convert model result from child
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
@ -181,11 +180,11 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
};
status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Receive result model from child process failed";
return FAILED;
return status;
}
return SUCCESS;
return kSuccess;
};
auto child_process = [this](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process);
@ -196,25 +195,25 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
return reinterpret_cast<uint8_t *>(model.MutableData());
};
auto status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Receive original model from parent process failed";
return FAILED;
return status;
}
Buffer model_result = LoadAscendIRInner(model);
if (model_result.DataSize() == 0) {
MS_LOG_ERROR << "Convert model from MindIR to OM failed";
return FAILED;
return kMCFailed;
}
// send result model to parent
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Send result model to parent process failed";
return FAILED;
return status;
}
return SUCCESS;
return kSuccess;
};
auto status = multi_process.MainProcess(parent_process, child_process);
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Convert MindIR model to OM model failed";
} else {
MS_LOG_INFO << "Convert MindIR model to OM model success";
@ -229,9 +228,9 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
MS_EXCEPTION_IF_NULL(multi_process);
// send original model to child
auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Send original model to child process failed";
return FAILED;
return status;
}
// receive convert model result from child
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
@ -239,11 +238,11 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
};
status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Receive result model from child process failed";
return FAILED;
return status;
}
return SUCCESS;
return kSuccess;
};
auto child_process = [this](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process);
@ -254,25 +253,25 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
return reinterpret_cast<uint8_t *>(model.MutableData());
};
auto status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Receive original model from parent process failed";
return FAILED;
return status;
}
Buffer model_result = LoadAscendIRInner(model);
if (model_result.DataSize() == 0) {
MS_LOG_ERROR << "Convert model from AIR to OM failed";
return FAILED;
return kMCFailed;
}
// send result model to parent
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Send result model to parent process failed";
return FAILED;
return status;
}
return SUCCESS;
return kSuccess;
};
auto status = multi_process.MainProcess(parent_process, child_process);
if (!status.IsSuccess()) {
if (status != kSuccess) {
MS_LOG_ERROR << "Convert AIR model to OM model failed";
} else {
MS_LOG_INFO << "Convert AIR model to OM model success";
@ -326,4 +325,4 @@ Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
auto om_data = BuildAirModel(df_graph, init_options, build_options);
return om_data;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -27,7 +27,7 @@
#include "external/ge/ge_ir_build.h"
#include "cxx_api/model/acl/acl_model_options.h"
namespace mindspore::api {
namespace mindspore {
class ModelConverter {
public:
ModelConverter() : options_(nullptr) {}
@ -46,6 +46,5 @@ class ModelConverter {
Buffer LoadMindIRInner(const FuncGraphPtr &func_graph);
Buffer LoadAscendIRInner(const Buffer &model_data);
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H

View File

@ -19,49 +19,45 @@
#include "cxx_api/factory.h"
#include "utils/utils.h"
namespace mindspore::api {
Status Model::Build(const std::map<std::string, std::string> &options) {
namespace mindspore {
Status Model::Build() {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Build(options);
return impl_->Build();
}
Status Model::Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) {
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Train(dataset, outputs);
return impl_->Resize(inputs, dims);
}
Status Model::Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Eval(dataset, outputs);
}
Status Model::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Predict(inputs, outputs);
}
Status Model::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> Model::GetInputs() {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetInputsInfo(names, shapes, data_types, mem_sizes);
return impl_->GetInputs();
}
Status Model::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> Model::GetOutputs() {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
return impl_->GetOutputs();
}
Model::Model(const GraphCell &graph_cell)
: impl_(Factory<ModelImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
Model::Model(const GraphCell &graph_cell, const std::shared_ptr<Context> &model_context)
: impl_(Factory<ModelImpl>::Instance().Create(mindspore::GlobalContext::GetGlobalDeviceTarget())) {
if (impl_ == nullptr) {
MS_LOG(EXCEPTION) << "Create session type " << Context::Instance().GetDeviceTarget() << " failed";
MS_LOG(EXCEPTION) << "Create session type " << mindspore::GlobalContext::GetGlobalDeviceTarget() << " failed";
}
MS_EXCEPTION_IF_NULL(graph_cell.GetGraph());
impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
impl_->SetContext(model_context);
}
Model::Model(const std::vector<Output> &network) { MS_LOG(EXCEPTION) << "Unsupported feature."; }
Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context) {
MS_LOG(EXCEPTION) << "Unsupported feature.";
}
Model::~Model() {}
@ -69,4 +65,4 @@ bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -24,7 +24,6 @@
#include "cxx_api/model/model_converter_utils/shared_memory.h"
namespace mindspore {
namespace api {
namespace {
uint64_t kSharedMemorySize = 100ull << 20; // 100 MB
}
@ -40,7 +39,7 @@ Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall
memory_size_ = kSharedMemorySize; // 100 MB
SharedMemory shared_memory;
ret = shared_memory.Create(memory_size_);
if (!ret.IsSuccess()) {
if (ret != kSuccess) {
MS_LOG_ERROR << "Create shared memory failed";
return ret;
}
@ -48,10 +47,10 @@ Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall
if (pid < 0) {
shared_memory.Destroy();
MS_LOG_ERROR << "Fork process to convert model failed";
return FAILED;
return kMEFailed;
}
ret = shared_memory.Attach();
if (!ret.IsSuccess()) {
if (ret != kSuccess) {
MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid;
return ret;
}
@ -87,12 +86,12 @@ Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) {
Status ret;
try {
ret = parent_process(this);
if (!ret.IsSuccess()) {
if (ret != kSuccess) {
MS_LOG_ERROR << "Parent process process failed";
}
} catch (const std::runtime_error &ex) {
MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what();
ret = FAILED;
ret = kMEFailed;
}
stopped_ = true;
send_msg_->stop = true;
@ -108,7 +107,7 @@ void MultiProcess::ChildProcess(ProcessFuncCall child_process) {
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
try {
auto ret = child_process(this);
if (!ret.IsSuccess()) {
if (ret != kSuccess) {
MS_LOG_ERROR << "Child process process failed";
}
} catch (const std::runtime_error &ex) {
@ -138,14 +137,14 @@ Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
}
if (peer_stopped_) {
if (!send_msg_->read_finish_flag) {
return FAILED;
return kMEFailed;
}
break;
}
MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
}
MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len;
return SUCCESS;
return kSuccess;
}
Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
@ -158,7 +157,7 @@ Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
usleep(1000); // 1ms
}
if (peer_stopped_) {
return FAILED;
return kMEFailed;
}
if (msg_buffer == nullptr) {
msg_len = receive_msg_->msg_total_len;
@ -170,7 +169,7 @@ Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
receive_msg_->read_finish_flag = true;
MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
} while (msg_len > cur_offset);
return SUCCESS;
return kSuccess;
}
void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
@ -200,6 +199,4 @@ void MultiProcess::HeartbeatThreadFuncInner() {
usleep(100000); // sleep 100 ms
}
}
} // namespace api
} // namespace mindspore

View File

@ -21,7 +21,6 @@
#include "include/api/status.h"
namespace mindspore {
namespace api {
struct MessageFlag {
uint64_t heartbeat = 0;
uint64_t stop = false;
@ -60,7 +59,5 @@ class MultiProcess {
Status ParentProcess(ProcessFuncCall parent_process);
void ChildProcess(ProcessFuncCall child_process);
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H

View File

@ -20,26 +20,25 @@
#include "mindspore/core/utils/log_adapter.h"
namespace mindspore {
namespace api {
Status SharedMemory::Create(uint64_t memory_size) {
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode);
if (shm_id_ == -1) {
MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno);
return FAILED;
return kMCFailed;
}
MS_LOG_INFO << "shmget success, shm id " << shm_id_;
return SUCCESS;
return kSuccess;
}
Status SharedMemory::Attach() {
void *shmat_addr = shmat(shm_id_, nullptr, 0);
if (shmat_addr == reinterpret_cast<void *>(-1)) {
MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno);
return FAILED;
return kMCFailed;
}
shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr);
return SUCCESS;
return kSuccess;
}
void SharedMemory::Detach() {
@ -63,5 +62,4 @@ void SharedMemory::Destroy() {
MS_LOG_ERROR << errMsg;
}
}
} // namespace api
} // namespace mindspore

View File

@ -20,7 +20,6 @@
#include "include/api/status.h"
namespace mindspore {
namespace api {
class SharedMemory {
public:
Status Create(uint64_t memory_size);
@ -33,7 +32,5 @@ class SharedMemory {
int shm_id_ = -1;
uint8_t *shmat_addr_ = nullptr;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H

View File

@ -21,28 +21,26 @@
#include <vector>
#include <memory>
#include <utility>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"
#include "ir/func_graph.h"
namespace mindspore::api {
namespace mindspore {
class ModelImpl {
public:
ModelImpl() = default;
virtual ~ModelImpl() = default;
virtual Status Build(const std::map<std::string, std::string> &options) = 0;
virtual Status Build() = 0;
virtual Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) = 0;
virtual Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
virtual Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
virtual Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;
protected:
Status Load(const std::shared_ptr<GraphCell> &graph_cell) {
@ -61,11 +59,16 @@ class ModelImpl {
}
std::shared_ptr<Graph> graph_;
std::shared_ptr<Context> model_context_;
private:
friend class Model;
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
void SetContext(const std::shared_ptr<Context> &model_context) {
if (model_context != nullptr) {
model_context_ = std::make_shared<Context>(*model_context);
}
}
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H

View File

@ -16,18 +16,78 @@
#include "cxx_api/model/ms/ms_model.h"
#include <memory>
#include "include/api/context.h"
#include "utils/ms_context.h"
#include "cxx_api/factory.h"
namespace mindspore {
namespace api {
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
API_FACTORY_REG(ModelImpl, GPU, MsModel);
Status MsModel::Build(const std::map<std::string, std::string> &) {
static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
std::string shape_key;
for (size_t i = 0; i < dims.size(); ++i) {
shape_key += std::to_string(i) + ":";
for (size_t j = 0; j < dims[i].size(); ++j) {
shape_key += std::to_string(dims[i][j]);
if (j + 1 < dims[i].size()) {
shape_key += ",";
}
}
if (i + 1 < dims.size()) {
shape_key += ";";
}
}
return shape_key;
}
std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims) {
std::string shape_key = GenerateShapeKey(dims);
if (auto iter = dynamic_size_graph_map_.find(shape_key); iter != dynamic_size_graph_map_.end()) {
MS_LOG(INFO) << "This options has been built, read cache.";
return iter->second;
}
auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &inputs = func_graph->parameters();
if (dims.size() != inputs.size()) {
MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match model inputs size " << inputs.size();
return nullptr;
}
for (size_t i = 0; i < dims.size(); ++i) {
const auto &param = inputs[i];
auto shape_ptr = std::dynamic_pointer_cast<abstract::Shape>(param->Shape());
if (shape_ptr == nullptr) {
MS_LOG(ERROR) << "Inputs " << i << " is not supported to resize, debug string: " << param->DebugString();
return nullptr;
}
shape_ptr->shape() = dims[i];
}
auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(func_graph, ModelType::kMindIR));
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
auto ret = ModelImpl::Load(graph_cell);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";
return nullptr;
}
dynamic_size_graph_map_[shape_key] = graph_cell;
return graph_cell;
}
Status MsModel::Build() {
MS_LOG(INFO) << "Start build model.";
MS_EXCEPTION_IF_NULL(graph_);
if (graph_cell_ != nullptr) {
MS_LOG(INFO) << "This model has been built, skip.";
return kSuccess;
}
auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -36,7 +96,7 @@ Status MsModel::Build(const std::map<std::string, std::string> &) {
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
auto ret = ModelImpl::Load(graph_cell);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";
return ret;
}
@ -44,55 +104,66 @@ Status MsModel::Build(const std::map<std::string, std::string> &) {
// save result
graph_cell_ = graph_cell;
MS_LOG(INFO) << "Build model success.";
return SUCCESS;
return kSuccess;
}
Status MsModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
Status MsModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
MS_LOG(INFO) << "Start to resize model";
auto origin_inputs = GetInputs();
if (inputs.size() != origin_inputs.size()) {
MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size();
return kMCInvalidInput;
}
if (inputs.size() != dims.size()) {
MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size();
return kMCInvalidInput;
}
auto graph_cell = GenerateGraphCell(dims);
if (graph_cell == nullptr) {
MS_LOG(ERROR) << "GenerateGraphCell failed.";
return kMCFailed;
}
MS_LOG(INFO) << "Resize model success.";
graph_cell_ = std::move(graph_cell);
return kSuccess;
}
Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
}
Status MsModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status MsModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid data, graph_ is null.";
return FAILED;
return kMCFailed;
}
if (graph_cell_ == nullptr) {
MS_LOG(INFO) << "Model has not been built, it will be built with default options";
Status ret = Build({});
if (ret != SUCCESS) {
Status ret = Build();
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed.";
return FAILED;
return ret;
}
}
MS_EXCEPTION_IF_NULL(graph_cell_);
Status ret = graph_cell_->Run(inputs, outputs);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Run graph failed.";
return FAILED;
return ret;
}
return SUCCESS;
return kSuccess;
}
Status MsModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> MsModel::GetInputs() {
MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes);
return graph_cell_->GetInputs();
}
Status MsModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> MsModel::GetOutputs() {
MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
return graph_cell_->GetOutputs();
}
} // namespace api
} // namespace mindspore

View File

@ -33,26 +33,24 @@
#endif
namespace mindspore {
namespace api {
class MsModel : public ModelImpl {
public:
MsModel() {}
~MsModel() = default;
Status Build(const std::map<std::string, std::string> &options_map) override;
Status Build() override;
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
private:
std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims);
std::shared_ptr<GraphCell> graph_cell_;
std::map<std::string, std::shared_ptr<GraphCell>> dynamic_size_graph_map_;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H

View File

@ -15,7 +15,7 @@
*/
#include "include/api/ops/ops.h"
namespace mindspore::api {
namespace mindspore {
Conv2D::Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, int group)
: OpCell("Conv2D"),
@ -35,4 +35,4 @@ Output Conv2D::operator()(const Input &input1, const Input &input2) const {
std::vector<Output> Conv2D::Construct(const std::vector<Input> &inputs) {
return {Output(shared_from_this(), inputs, 1)};
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -29,7 +29,7 @@ namespace py = pybind11;
static std::mutex init_mutex;
static bool Initialized = false;
namespace mindspore::api {
namespace mindspore {
static void RegAllOpFromPython() {
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
Py_Initialize();
@ -143,4 +143,4 @@ PythonEnvGuard::~PythonEnvGuard() {
FinalizePython();
}
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -16,7 +16,7 @@
#ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
#define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
namespace mindspore::api {
namespace mindspore {
void RegAllOp();
bool PythonIsInited();
void InitPython();
@ -30,5 +30,5 @@ class PythonEnvGuard {
private:
bool origin_init_status_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H

View File

@ -19,7 +19,7 @@
#include "utils/log_adapter.h"
#include "mindspore/core/load_mindir/load_model.h"
namespace mindspore::api {
namespace mindspore {
static Buffer ReadFile(const std::string &file) {
Buffer buffer;
if (file.empty()) {
@ -68,6 +68,22 @@ static Buffer ReadFile(const std::string &file) {
return buffer;
}
Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelType model_type) {
if (model_type == kMindIR) {
FuncGraphPtr anf_graph = nullptr;
try {
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
} catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Load MindIR failed.";
}
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
} else if (model_type == kOM) {
return Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
}
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
}
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Buffer data = ReadFile(file);
if (data.Data() == nullptr) {
@ -77,7 +93,7 @@ Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
FuncGraphPtr anf_graph = nullptr;
try {
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize());
} catch (std::exception &e) {
} catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Load MindIR failed.";
}
@ -90,21 +106,21 @@ Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
return kMEFailed;
}
Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
return kMEFailed;
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
return kMEFailed;
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
return kMEFailed;
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -17,17 +17,20 @@
#include <numeric>
#include "securec/include/securec.h"
#include "utils/utils.h"
#include "mindspore/core/ir/api_tensor_impl.h"
namespace mindspore::api {
const char *kDeviceTypeAscend310 = "Ascend310";
const char *kDeviceTypeAscend910 = "Ascend910";
const char *kDeviceTypeGpu = "GPU";
class DataImpl {
namespace mindspore {
class Buffer::Impl {
public:
DataImpl() : data_() {}
~DataImpl() = default;
DataImpl(const void *data, size_t data_len) { SetData(data, data_len); }
Impl() : data_() {}
~Impl() = default;
Impl(const void *data, size_t data_len) {
if (data != nullptr) {
(void)SetData(data, data_len);
} else {
ResizeData(data_len);
}
}
const void *Data() const { return data_.data(); }
void *MutableData() { return data_.data(); }
@ -66,132 +69,162 @@ class DataImpl {
std::vector<uint8_t> data_;
};
class Buffer::Impl : public DataImpl {
class TensorDefaultImpl : public MSTensor::Impl {
public:
Impl() : DataImpl() {}
~Impl() = default;
Impl(const void *data, size_t data_len) : DataImpl(data, data_len) {}
};
TensorDefaultImpl() : buffer_(), name_(), type_(DataType::kTypeUnknown), shape_() {}
~TensorDefaultImpl() override = default;
TensorDefaultImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: buffer_(data, data_len), name_(name), type_(type), shape_(shape) {}
class Tensor::Impl : public DataImpl {
public:
Impl() : DataImpl(), name_(), type_(DataType::kMsUnknown), shape_() {}
~Impl() = default;
Impl(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: DataImpl(data, data_len), name_(name), type_(type), shape_(shape) {}
const std::string &Name() const override { return name_; }
enum DataType DataType() const override { return type_; }
const std::vector<int64_t> &Shape() const override { return shape_; }
const std::string &Name() const { return name_; }
void SetName(const std::string &name) { name_ = name; }
api::DataType DataType() const { return type_; }
void SetDataType(api::DataType type) { type_ = type; }
void SetShape(const std::vector<int64_t> &shape) { shape_ = shape; }
const std::vector<int64_t> &Shape() const { return shape_; }
int64_t ElementNum() const {
std::vector<int64_t> shapex = Shape();
return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
std::shared_ptr<const void> Data() const override {
return std::shared_ptr<const void>(buffer_.Data(), [](const void *) {});
}
static int GetTypeSize(api::DataType type) {
static const std::map<api::DataType, size_t> type_size_map = {
{kMsBool, sizeof(bool)}, {kMsFloat64, sizeof(double)}, {kMsInt8, sizeof(int8_t)},
{kMsUint8, sizeof(uint8_t)}, {kMsInt16, sizeof(int16_t)}, {kMsUint16, sizeof(uint16_t)},
{kMsInt32, sizeof(int32_t)}, {kMsUint32, sizeof(uint32_t)}, {kMsInt64, sizeof(int64_t)},
{kMsUint64, sizeof(uint64_t)}, {kMsFloat16, sizeof(uint16_t)}, {kMsFloat32, sizeof(float)},
};
auto it = type_size_map.find(type);
if (it != type_size_map.end()) {
return it->second;
}
void *MutableData() override { return buffer_.MutableData(); }
size_t DataSize() const override { return buffer_.DataSize(); }
MS_LOG(WARNING) << "Cannot find data type " << type;
return 0;
bool IsDevice() const override { return false; }
std::shared_ptr<Impl> Clone() const override {
return std::make_shared<TensorDefaultImpl>(name_, type_, shape_, buffer_.Data(), buffer_.DataSize());
}
private:
Buffer buffer_;
std::string name_;
api::DataType type_;
enum DataType type_;
std::vector<int64_t> shape_;
};
Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
Tensor::Tensor(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: impl_(std::make_shared<Impl>(name, type, shape, data, data_len)) {}
Tensor::~Tensor() = default;
class TensorReferenceImpl : public MSTensor::Impl {
public:
TensorReferenceImpl() : data_(nullptr), data_size_(0), name_(), type_(DataType::kTypeUnknown), shape_() {}
~TensorReferenceImpl() override = default;
TensorReferenceImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: data_(data), data_size_(data_len), name_(name), type_(type), shape_(shape) {}
Tensor Tensor::Clone() const {
const std::string &Name() const override { return name_; }
enum DataType DataType() const override { return type_; }
const std::vector<int64_t> &Shape() const override { return shape_; }
std::shared_ptr<const void> Data() const override {
return std::shared_ptr<const void>(data_, [](const void *) {});
}
void *MutableData() override { return const_cast<void *>(data_); }
size_t DataSize() const override { return data_size_; }
bool IsDevice() const override { return false; }
std::shared_ptr<Impl> Clone() const override {
return std::make_shared<TensorReferenceImpl>(name_, type_, shape_, data_, data_size_);
}
protected:
const void *data_;
size_t data_size_;
std::string name_;
enum DataType type_;
std::vector<int64_t> shape_;
};
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
return MSTensor(nullptr);
} catch (...) {
MS_LOG(ERROR) << "Unknown error occurred.";
return MSTensor(nullptr);
}
}
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
return MSTensor(nullptr);
} catch (...) {
MS_LOG(ERROR) << "Unknown error occurred.";
return MSTensor(nullptr);
}
}
MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); }
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default;
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
MSTensor MSTensor::Clone() const {
MS_EXCEPTION_IF_NULL(impl_);
Tensor ret;
ret.impl_ = std::make_shared<Impl>(*impl_);
MSTensor ret;
ret.impl_ = impl_->Clone();
return ret;
}
const std::string &Tensor::Name() const {
const std::string &MSTensor::Name() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Name();
}
void Tensor::SetName(const std::string &name) {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetName(name);
}
DataType Tensor::DataType() const {
enum DataType MSTensor::DataType() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->DataType();
}
void Tensor::SetDataType(api::DataType type) {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetDataType(type);
}
const std::vector<int64_t> &Tensor::Shape() const {
const std::vector<int64_t> &MSTensor::Shape() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Shape();
}
void Tensor::SetShape(const std::vector<int64_t> &shape) {
int64_t MSTensor::ElementNum() const {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetShape(shape);
const auto &shape = impl_->Shape();
if (shape.empty()) {
// element number of scalar is 1
return 1;
}
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
}
const void *Tensor::Data() const {
std::shared_ptr<const void> MSTensor::Data() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Data();
}
void *Tensor::MutableData() {
void *MSTensor::MutableData() {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->MutableData();
}
size_t Tensor::DataSize() const {
size_t MSTensor::DataSize() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->DataSize();
}
bool Tensor::ResizeData(size_t data_len) {
bool MSTensor::IsDevice() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->ResizeData(data_len);
return impl_->IsDevice();
}
bool Tensor::SetData(const void *data, size_t data_len) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->SetData(data, data_len);
}
int64_t Tensor::ElementNum() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->ElementNum();
}
int Tensor::GetTypeSize(api::DataType type) { return Impl::GetTypeSize(type); }
Buffer::Buffer() : impl_(std::make_shared<Impl>()) {}
Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {}
Buffer::~Buffer() = default;
@ -227,4 +260,4 @@ bool Buffer::SetData(const void *data, size_t data_len) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->SetData(data, data_len);
}
} // namespace mindspore::api
} // namespace mindspore

View File

@ -284,14 +284,7 @@ else()
endif()
add_dependencies(_c_dataengine mindspore_shared_lib)
if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
set(MINDSPORE_LINK_OBJECT ${CMAKE_BINARY_DIR}/mindspore/ccsrc/cxx_api/CMakeFiles/mindspore_shared_lib.dir/objects.a)
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib ${MINDSPORE_LINK_OBJECT})
else()
if(ENABLE_ACL)
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
endif()
endif()
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
if(USE_GLOG)
target_link_libraries(_c_dataengine PRIVATE mindspore::glog)

View File

@ -26,28 +26,13 @@ if(ENABLE_PYTHON)
target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
endif()
if(ENABLE_ACL)
add_library(cpp-API OBJECT
config.cc
datasets.cc
execute.cc
iterator.cc
minddata_eager.cc
transforms.cc
samplers.cc
text.cc
vision.cc
)
else()
add_library(cpp-API OBJECT
config.cc
datasets.cc
execute.cc
iterator.cc
transforms.cc
samplers.cc
text.cc
vision.cc
)
endif()
add_library(cpp-API OBJECT
config.cc
datasets.cc
execute.cc
iterator.cc
transforms.cc
samplers.cc
text.cc
vision.cc
)

View File

@ -1,142 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/include/de_tensor.h"
#include "minddata/dataset/include/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "mindspore/lite/include/ms_tensor.h"
#include "utils/hashing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#else
#include "mindspore/lite/src/common/log_adapter.h"
#endif
namespace mindspore {
namespace tensor {
MSTensor *DETensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
return new DETensor(data_type, shape);
}
MSTensor *DETensor::CreateTensor(const std::string &path) {
std::shared_ptr<dataset::Tensor> t;
(void)dataset::Tensor::CreateFromFile(path, &t);
return new DETensor(std::move(t));
}
MSTensor *DETensor::CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data) {
std::shared_ptr<dataset::Tensor> t;
// prepare shape info
std::vector<dataset::dsize_t> t_shape;
std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
(void)dataset::Tensor::CreateFromMemory(dataset::TensorShape(t_shape), dataset::MSTypeToDEType(data_type),
static_cast<uint8_t *>(data), &t);
return new DETensor(std::move(t));
}
DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) {
std::vector<dataset::dsize_t> t_shape;
t_shape.reserve(shape.size());
std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
dataset::Tensor::CreateEmpty(dataset::TensorShape(t_shape), dataset::MSTypeToDEType(data_type), &this->tensor_impl_);
}
DETensor::DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
MSTensor *DETensor::ConvertToLiteTensor() {
// static MSTensor::CreateTensor is only for the LiteTensor
MSTensor *tensor = CreateTensor(this->data_type(), this->shape());
MS_ASSERT(tensor->Size() == this->Size());
memcpy_s(tensor->MutableData(), tensor->Size(), this->MutableData(), this->Size());
return tensor;
}
std::shared_ptr<dataset::Tensor> DETensor::tensor() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_;
}
TypeId DETensor::data_type() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return dataset::DETypeToMSType(this->tensor_impl_->type());
}
TypeId DETensor::set_data_type(TypeId data_type) {
MS_ASSERT(this->tensor_impl_ != nullptr);
if (data_type != this->data_type()) {
std::shared_ptr<dataset::Tensor> temp;
dataset::Tensor::CreateFromMemory(this->tensor_impl_->shape(), dataset::MSTypeToDEType(data_type),
this->tensor_impl_->GetBuffer(), &temp);
this->tensor_impl_ = temp;
}
return data_type;
}
std::vector<int> DETensor::shape() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
std::vector<dataset::dsize_t> t_shape = this->tensor_impl_->shape().AsVector();
std::vector<int> shape;
shape.reserve(t_shape.size());
std::transform(t_shape.begin(), t_shape.end(), std::back_inserter(shape),
[](dataset::dsize_t s) -> int { return static_cast<int>(s); });
return shape;
}
size_t DETensor::set_shape(const std::vector<int> &shape) {
MS_ASSERT(this->tensor_impl_ != nullptr);
std::vector<dataset::dsize_t> t_shape;
t_shape.reserve(shape.size());
std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
dataset::Status rc = this->tensor_impl_->Reshape(dataset::TensorShape(t_shape));
return shape.size();
}
int DETensor::DimensionSize(size_t index) const {
MS_ASSERT(this->tensor_impl_ != nullptr);
int dim_size = -1;
auto shape = this->shape();
if (index < shape.size()) {
dim_size = shape[index];
} else {
MS_LOG(ERROR) << "Dimension index is wrong: " << index;
}
return dim_size;
}
int DETensor::ElementsNum() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_->Size();
}
size_t DETensor::Size() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_->SizeInBytes();
}
void *DETensor::MutableData() {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_->GetMutableBuffer();
}
} // namespace tensor
} // namespace mindspore

View File

@ -14,12 +14,11 @@
* limitations under the License.
*/
#include "minddata/dataset/core/tensor_row.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/include/de_tensor.h"
#endif
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/kernels/tensor_op.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
@ -30,78 +29,85 @@
namespace mindspore {
namespace dataset {
Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}
Execute::Execute(std::shared_ptr<TensorOperation> op) { ops_.emplace_back(std::move(op)); }
/// \brief Destructor
Execute::~Execute() = default;
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops) : ops_(std::move(ops)) {}
#ifdef ENABLE_ANDROID
std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) {
// Build the op
if (op_ == nullptr) {
MS_LOG(ERROR) << "Input TensorOperation is not valid";
return nullptr;
Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
// Validate input tensor
CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data");
CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");
// Validate and build runtime ops
std::vector<std::shared_ptr<TensorOp>> transforms;
for (int32_t i = 0; i < ops_.size(); i++) {
CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
transforms.emplace_back(ops_[i]->Build());
}
std::shared_ptr<Tensor> de_input = std::dynamic_pointer_cast<tensor::DETensor>(input)->tensor();
if (de_input == nullptr) {
MS_LOG(ERROR) << "Input Tensor is not valid";
return nullptr;
}
std::shared_ptr<TensorOp> transform = op_->Build();
std::shared_ptr<Tensor> de_output;
Status rc = transform->Compute(de_input, &de_output);
// Convert mindspore::Tensor to dataset::Tensor
std::shared_ptr<dataset::Tensor> de_tensor;
Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(input.Shape()),
MSTypeToDEType(static_cast<TypeId>(input.DataType())),
(const uchar *)(input.Data().get()), input.DataSize(), &de_tensor);
RETURN_IF_NOT_OK(rc);
if (rc.IsError()) {
// execution failed
MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
return nullptr;
}
return std::make_shared<tensor::DETensor>(std::move(de_output));
}
#endif
// Apply transforms on tensor
for (auto &t : transforms) {
std::shared_ptr<dataset::Tensor> de_output;
RETURN_IF_NOT_OK(t->Compute(de_tensor, &de_output));
std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) {
// Build the op
if (op_ == nullptr) {
MS_LOG(ERROR) << "Input TensorOperation is not valid";
return nullptr;
// For next transform
de_tensor = std::move(de_output);
}
if (input == nullptr) {
MS_LOG(ERROR) << "Input Tensor is not valid";
return nullptr;
}
// will add validate params once API is set
std::shared_ptr<TensorOp> transform = op_->Build();
std::shared_ptr<Tensor> de_output;
Status rc = transform->Compute(input, &de_output);
if (rc.IsError()) {
// execution failed
MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
return nullptr;
}
return de_output;
// Convert dataset::Tensor to mindspore::Tensor
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
*output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
return Status::OK();
}
Status Execute::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
std::vector<std::shared_ptr<Tensor>> *output_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(op_ != nullptr, "Input TensorOperation is not valid");
Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
// Validate input tensor
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid");
for (auto &tensor : input_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data");
}
CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");
TensorRow input, output;
std::copy(input_tensor_list.begin(), input_tensor_list.end(), std::back_inserter(input));
CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "Input Tensor is not valid");
std::shared_ptr<TensorOp> transform = op_->Build();
Status rc = transform->Compute(input, &output);
if (rc.IsError()) {
// execution failed
RETURN_STATUS_UNEXPECTED("Operation execution failed : " + rc.ToString());
// Validate and build runtime ops
std::vector<std::shared_ptr<TensorOp>> transforms;
for (int32_t i = 0; i < ops_.size(); i++) {
CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
transforms.emplace_back(ops_[i]->Build());
}
std::copy(output.begin(), output.end(), std::back_inserter(*output_tensor_list));
TensorRow de_tensor_list;
for (auto &tensor : input_tensor_list) {
std::shared_ptr<dataset::Tensor> de_tensor;
Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(tensor.Shape()),
MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
(const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_tensor);
RETURN_IF_NOT_OK(rc);
de_tensor_list.emplace_back(std::move(de_tensor));
}
// Apply transforms on tensor
for (auto &t : transforms) {
TensorRow de_output_list;
RETURN_IF_NOT_OK(t->Compute(de_tensor_list, &de_output_list));
// For next transform
de_tensor_list = std::move(de_output_list);
}
for (auto &tensor : de_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(tensor->HasData(), "Apply transform failed, output tensor has no data");
auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
output_tensor_list->emplace_back(ms_tensor);
}
CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid");
return Status::OK();
}

View File

@ -1,154 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unistd.h>
#include <unordered_map>
#include "minddata/dataset/include/minddata_eager.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace api {
MindDataEager::MindDataEager(std::vector<std::shared_ptr<dataset::TensorOperation>> ops) : ops_(ops) {}
// Helper function to convert Type from DE to MS
DataType ToMSType(dataset::DataType type) {
switch (dataset::DataType::Type(type)) {
case dataset::DataType::DE_BOOL:
return DataType::kMsBool;
case dataset::DataType::DE_UINT8:
return DataType::kMsUint8;
case dataset::DataType::DE_INT32:
return DataType::kMsInt32;
case dataset::DataType::DE_INT64:
return DataType::kMsInt64;
case dataset::DataType::DE_FLOAT32:
return DataType::kMsFloat32;
default:
return DataType::kMsUnknown;
}
}
// Helper function to convert Type from MS to DE
dataset::DataType ToDEType(DataType type) {
switch (type) {
case DataType::kMsBool:
return dataset::DataType(dataset::DataType::DE_BOOL);
case DataType::kMsUint8:
return dataset::DataType(dataset::DataType::DE_UINT8);
case DataType::kMsInt32:
return dataset::DataType(dataset::DataType::DE_INT32);
case DataType::kMsInt64:
return dataset::DataType(dataset::DataType::DE_INT64);
case DataType::kMsFloat32:
return dataset::DataType(dataset::DataType::DE_FLOAT32);
default:
return dataset::DataType(dataset::DataType::DE_UNKNOWN);
}
}
Status MindDataEager::LoadImageFromDir(const std::string &image_dir, std::vector<std::shared_ptr<Tensor>> *images) {
// Check target directory
dataset::Path image_dir_(image_dir);
if (!image_dir_.Exists() || !image_dir_.IsDirectory()) {
std::string err_msg = "Target directory: " + image_dir + " does not exist or not a directory.";
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::FAILED, err_msg);
}
if (access(image_dir_.toString().c_str(), R_OK) == -1) {
std::string err_msg = "No access to target directory: " + image_dir;
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::FAILED, err_msg);
}
// Start reading images and constructing tensors
auto path_itr = dataset::Path::DirIterator::OpenDirectory(&image_dir_);
while (path_itr->hasNext()) {
dataset::Path file = path_itr->next();
std::shared_ptr<dataset::Tensor> image;
dataset::Tensor::CreateFromFile(file.toString(), &image);
std::shared_ptr<Tensor> ms_image = std::make_shared<Tensor>("image", DataType(kMsUint8), image->shape().AsVector(),
image->GetBuffer(), image->SizeInBytes());
images->push_back(ms_image);
}
// Check if read images or not
if (images->empty()) {
std::string err_msg = "No images found in target directory: " + image_dir;
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::FAILED, err_msg);
}
return Status(StatusCode::SUCCESS);
}
std::shared_ptr<Tensor> MindDataEager::operator()(std::shared_ptr<Tensor> input) {
// Validate ops
if (ops_.empty()) {
MS_LOG(ERROR) << "Input TensorOperation should be provided";
return nullptr;
}
for (int32_t i = 0; i < ops_.size(); i++) {
if (ops_[i] == nullptr) {
MS_LOG(ERROR) << "Input TensorOperation[" << i << "] is invalid or null";
return nullptr;
}
}
// Validate input tensor
if (input == nullptr) {
MS_LOG(ERROR) << "Input Tensor should not be null";
return nullptr;
}
// Start applying transforms in ops
std::shared_ptr<dataset::Tensor> de_input;
dataset::Tensor::CreateFromMemory(dataset::TensorShape(input->Shape()), ToDEType(input->DataType()),
(const uchar *)(input->Data()), &de_input);
for (int32_t i = 0; i < ops_.size(); i++) {
// Build runtime op and run
std::shared_ptr<dataset::Tensor> de_output;
std::shared_ptr<dataset::TensorOp> transform = ops_[i]->Build();
dataset::Status rc = transform->Compute(de_input, &de_output);
// check execution failed
if (rc.IsError()) {
MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
return nullptr;
}
// For next transform
de_input = std::move(de_output);
}
// Convert DETensor to Tensor
if (!de_input->HasData()) {
MS_LOG(ERROR) << "Apply transform failed, output tensor has no data";
return nullptr;
}
std::shared_ptr<Tensor> output =
std::make_shared<Tensor>("transfomed", ToMSType(de_input->type()), de_input->shape().AsVector(),
de_input->GetBuffer(), de_input->SizeInBytes());
return output;
}
} // namespace api
} // namespace mindspore

View File

@ -29,25 +29,42 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
return execute;
}))
.def("__call__",
[](Execute &self, std::shared_ptr<Tensor> in) {
std::shared_ptr<Tensor> out = self(in);
if (out == nullptr) {
THROW_IF_ERROR([]() {
RETURN_STATUS_UNEXPECTED(
"Failed to execute op in eager mode, please check ERROR log above.");
[](Execute &self, const std::shared_ptr<Tensor> &de_tensor) {
auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
Status rc = self(ms_tensor, &ms_tensor);
if (rc.IsError()) {
THROW_IF_ERROR([&rc]() {
RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString());
}());
}
return out;
std::shared_ptr<dataset::Tensor> de_output_tensor;
dataset::Tensor::CreateFromMemory(dataset::TensorShape(ms_tensor.Shape()),
MSTypeToDEType(static_cast<TypeId>(ms_tensor.DataType())),
(const uchar *)(ms_tensor.Data().get()),
ms_tensor.DataSize(), &de_output_tensor);
return de_output_tensor;
})
.def("__call__", [](Execute &self, const std::vector<std::shared_ptr<Tensor>> &input_tensor_list) {
std::vector<std::shared_ptr<Tensor>> output_tensor_list;
THROW_IF_ERROR(self(input_tensor_list, &output_tensor_list));
if (output_tensor_list.empty()) {
THROW_IF_ERROR([]() {
RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
}());
std::vector<MSTensor> ms_input_tensor_list;
std::vector<MSTensor> ms_output_tensor_list;
for (auto &tensor : input_tensor_list) {
auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
ms_input_tensor_list.emplace_back(std::move(ms_tensor));
}
return output_tensor_list;
Status rc = self(ms_input_tensor_list, &ms_output_tensor_list);
if (rc.IsError()) {
THROW_IF_ERROR(
[&rc]() { RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString()); }());
}
std::vector<std::shared_ptr<dataset::Tensor>> de_output_tensor_list;
for (auto &tensor : ms_output_tensor_list) {
std::shared_ptr<dataset::Tensor> de_output_tensor;
dataset::Tensor::CreateFromMemory(
dataset::TensorShape(tensor.Shape()), MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
(const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_output_tensor);
de_output_tensor_list.emplace_back(std::move(de_output_tensor));
}
return de_output_tensor_list;
});
}));
} // namespace dataset

View File

@ -84,7 +84,8 @@ PYBIND_REGISTER(SliceOption, 0, ([](const py::module *m) {
}
if (!c_slice.valid()) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
THROW_IF_ERROR(
Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
return SliceOption(c_slice);
}))

View File

@ -354,7 +354,7 @@ PYBIND_REGISTER(
for (auto handle : py_sub.cast<py::list>()) {
py::tuple tp = handle.cast<py::tuple>();
if (tp.is_none() || tp.size() != 2) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
THROW_IF_ERROR(Status(StatusCode::kMDUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
}
std::shared_ptr<TensorOperation> t_op;
if (py::isinstance<TensorOperation>(tp[0])) {
@ -366,11 +366,11 @@ PYBIND_REGISTER(
std::make_shared<PyFuncOp>((tp[0]).cast<py::function>()));
} else {
THROW_IF_ERROR(
Status(StatusCode::kUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
Status(StatusCode::kMDUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
}
double prob = (tp[1]).cast<py::float_>();
if (prob < 0 || prob > 1) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1]."));
THROW_IF_ERROR(Status(StatusCode::kMDUnexpectedError, "prob needs to be with [0,1]."));
}
cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
}

View File

@ -51,12 +51,12 @@ Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
}
try {
f(cb_param);
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
return Status(StatusCode::kMDPyFuncException, e.what());
}
}
return Status::OK();

View File

@ -5,6 +5,7 @@ set(DATASET_CORE_SRC_FILES
config_manager.cc
cv_tensor.cc
data_type.cc
de_tensor.cc
global_context.cc
tensor.cc
tensor_helpers.cc

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/include/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "utils/hashing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#define ASSERT_NULL(ptr) MS_EXCEPTION_IF_NULL(ptr)
#else
#include "mindspore/lite/src/common/log_adapter.h"
#define ASSERT_NULL(ptr) MS_ASSERT((ptr) != nullptr)
#endif
namespace mindspore {
namespace dataset {
DETensor::DETensor(std::shared_ptr<dataset::Tensor> tensor_impl)
: tensor_impl_(tensor_impl),
name_("MindDataTensor"),
type_(static_cast<mindspore::DataType>(DETypeToMSType(tensor_impl_->type()))),
shape_(tensor_impl_->shape().AsVector()) {}
const std::string &DETensor::Name() const { return name_; }
enum mindspore::DataType DETensor::DataType() const {
ASSERT_NULL(tensor_impl_);
return static_cast<mindspore::DataType>(DETypeToMSType(tensor_impl_->type()));
}
size_t DETensor::DataSize() const {
ASSERT_NULL(tensor_impl_);
return tensor_impl_->SizeInBytes();
}
const std::vector<int64_t> &DETensor::Shape() const { return shape_; }
std::shared_ptr<const void> DETensor::Data() const {
return std::shared_ptr<const void>(tensor_impl_->GetBuffer(), [](const void *) {});
}
void *DETensor::MutableData() {
ASSERT_NULL(tensor_impl_);
return tensor_impl_->GetMutableBuffer();
}
bool DETensor::IsDevice() const { return false; }
std::shared_ptr<mindspore::MSTensor::Impl> DETensor::Clone() const { return std::make_shared<DETensor>(tensor_impl_); }
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_
#include <string>
#include <vector>
#include <memory>
#include "include/api/types.h"
#include "mindspore/core/ir/api_tensor_impl.h"
#include "minddata/dataset/include/status.h"
#include "minddata/dataset/include/tensor.h"
namespace mindspore {
namespace dataset {
class DETensor : public mindspore::MSTensor::Impl {
public:
DETensor() = default;
~DETensor() override = default;
explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_impl);
const std::string &Name() const override;
enum mindspore::DataType DataType() const override;
size_t DataSize() const override;
const std::vector<int64_t> &Shape() const override;
std::shared_ptr<const void> Data() const override;
void *MutableData() override;
bool IsDevice() const override;
std::shared_ptr<mindspore::MSTensor::Impl> Clone() const override;
private:
std::shared_ptr<dataset::Tensor> tensor_impl_;
std::string name_;
enum mindspore::DataType type_;
std::vector<int64_t> shape_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_

View File

@ -41,23 +41,17 @@
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor_helpers.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/util/status.h"
#include "utils/ms_utils.h"
#ifndef ENABLE_ANDROID
#include "proto/example.pb.h"
#else
#include "minddata/dataset/include/de_tensor.h"
#endif
#ifdef ENABLE_PYTHON
namespace py = pybind11;
#endif
namespace mindspore {
#ifdef ENABLE_ANDROID
namespace tensor {
class DETensor;
} // namespace tensor
#endif
namespace dataset {
class Tensor;
template <typename T>
@ -85,7 +79,7 @@ class Tensor {
/// \param other Tensor to be moved
Tensor(Tensor &&other) noexcept;
/// Move assigment operator
/// Move assignment operator
/// \param other Tensor to be moved
Tensor &operator=(Tensor &&other) noexcept;
@ -134,7 +128,7 @@ class Tensor {
#ifndef ENABLE_ANDROID
/// Create a tensor of type DE_STRING from a BytesList.
/// \param[in] bytes_list protobuf's Bytelist
/// \param[in] shape shape of the outout tensor
/// \param[in] shape shape of the output tensor
/// \param[out] out created Tensor
/// \return Status Code
static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out);
@ -292,7 +286,7 @@ class Tensor {
std::string err;
err += (data_ == nullptr) ? "data_ is nullptr \t" : "";
err += type_.IsCompatible<T>() ? "data type not compatible\t" : "";
return Status(StatusCode::kUnexpectedError, err);
return Status(StatusCode::kMDUnexpectedError, err);
}
}
@ -343,7 +337,7 @@ class Tensor {
void Invalidate();
/// Copy input tensor into self at the location index.
/// Index is a vector of axises which can be incomplete:
/// Index is a vector of axes which can be incomplete:
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
/// \param index
/// \param input
@ -686,9 +680,7 @@ class Tensor {
unsigned char *data_end_ = nullptr;
private:
#ifdef ENABLE_ANDROID
friend class tensor::DETensor;
#endif
friend class DETensor;
/// Slice numeric tensors.
Status SliceNumeric(TensorPtr *out, const std::vector<std::vector<dsize_t>> &indices, const TensorShape &shape);

View File

@ -73,6 +73,7 @@ if(ENABLE_CACHE)
engine-cache-server
_c_dataengine
_c_mindrecord
mindspore
mindspore::protobuf
mindspore::grpc++
mindspore_gvar
@ -85,6 +86,7 @@ if(ENABLE_CACHE)
engine-cache-server
_c_dataengine
_c_mindrecord
mindspore
mindspore::protobuf
mindspore::grpc++
mindspore_gvar
@ -103,6 +105,7 @@ if(ENABLE_CACHE)
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread)
target_link_libraries(cache_admin mindspore mindspore_shared_lib)
if(USE_GLOG)
target_link_libraries(cache_admin mindspore::glog)

View File

@ -22,10 +22,11 @@
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/path.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
ds::Status rc;
ms::Status rc;
ds::CacheAdminArgHandler args;
std::stringstream arg_stream;

View File

@ -89,7 +89,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
// Flag that this arg is used now
@ -101,7 +101,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
@ -113,7 +113,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
*arg_stream >> value_as_string;
if (value_as_string.empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
// Now, attempt to convert the value into it's numeric format for output
@ -121,7 +121,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
*out_arg = std::stoul(value_as_string);
} catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
return Status::OK();
@ -133,7 +133,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
// Flag that this arg is used now
@ -145,7 +145,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
@ -158,12 +158,12 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
*arg_stream >> *out_arg;
} else {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
if (out_arg->empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
}
@ -176,7 +176,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
// Flag that this arg is used now
@ -188,7 +188,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
@ -200,7 +200,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
*arg_stream >> value_as_string;
if (value_as_string.empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
// Now, attempt to convert the value into it's string format for output
@ -208,7 +208,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
*out_arg = std::stof(value_as_string, nullptr);
} catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
return Status::OK();
@ -224,7 +224,7 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
if (hostname_ != std::string(kCfgDefaultCacheHost)) {
std::string err_msg =
"Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used.";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
break;
}
@ -304,7 +304,7 @@ Status CacheAdminArgHandler::Validate() {
if (!trailing_args_.empty()) {
std::string err_msg = "Invalid arguments provided: " + trailing_args_;
err_msg += "\nPlease try `cache_admin --help` for more information";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
// The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to
@ -312,18 +312,18 @@ Status CacheAdminArgHandler::Validate() {
if (command_id_ == CommandId::kCmdUnknown) {
std::string err_msg = "No command provided";
err_msg += "\nPlease try `cache_admin --help` for more information";
return Status(StatusCode::kSyntaxError, err_msg);
return Status(StatusCode::kMDSyntaxError, err_msg);
}
// Additional checks here
auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100);
if (num_workers_ < 1 || num_workers_ > max_num_workers)
return Status(StatusCode::kSyntaxError,
return Status(StatusCode::kMDSyntaxError,
"Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + ".");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kMDSyntaxError, "Log level must be in range (0..3).");
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1");
if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535).");
return Status(StatusCode::kMDSyntaxError, "Memory cap ratio should be positive and no greater than 1");
if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kMDSyntaxError, "Port must be in range (1025..65535).");
return Status::OK();
}
@ -467,9 +467,9 @@ Status CacheAdminArgHandler::StopServer(CommandId command_id) {
Status rc = rq->Wait();
if (rc.IsError()) {
msg.RemoveResourcesOnExit();
if (rc.IsNetWorkError()) {
if (rc == StatusCode::kMDNetWorkError) {
std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already.";
return Status(StatusCode::kNetWorkError, errMsg);
return Status(StatusCode::kMDNetWorkError, errMsg);
}
return rc;
}
@ -544,7 +544,7 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
if (WIFEXITED(status)) {
auto exit_status = WEXITSTATUS(status);
if (exit_status) {
return Status(StatusCode::kUnexpectedError, msg);
return Status(StatusCode::kMDUnexpectedError, msg);
} else {
// Not an error, some info message goes to stdout
std::cout << msg << std::endl;

View File

@ -75,7 +75,7 @@ Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, vo
do {
std::unique_lock<std::mutex> lock(mux_[slot]);
rc = shm_pool_[slot]->Allocate(sz, p);
if (rc.IsOutofMemory()) {
if (rc == StatusCode::kMDOutOfMemory) {
slot = (slot + 1) % shm_pool_.size();
}
} while (rc.IsError() && slot != begin_slot);

View File

@ -137,7 +137,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
Status CacheClient::AsyncWriteRow(const TensorRow &row) {
if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet);
return Status(StatusCode::kMDNotImplementedYet);
}
RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
return Status::OK();
@ -145,7 +145,7 @@ Status CacheClient::AsyncWriteRow(const TensorRow &row) {
Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet);
return Status(StatusCode::kMDNotImplementedYet);
} else {
Status rc;
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
@ -155,7 +155,7 @@ Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
TensorRow row;
RETURN_IF_NOT_OK(in->PopRow(&row));
rc = AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) {
if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
tensor_table->push_back(row);
} else if (rc.IsError()) {
return rc;
@ -165,7 +165,7 @@ Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
// If not all of them can be sent async, return what's left back to the caller.
if (!tensor_table->empty()) {
in->set_tensor_table(std::move(tensor_table));
return Status(StatusCode::kNotImplementedYet);
return Status(StatusCode::kMDNotImplementedYet);
}
}
return Status::OK();
@ -225,7 +225,8 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
auto cache_state = static_cast<CacheServiceState>(out);
if (cache_state == CacheServiceState::kFetchPhase ||
(cache_state == CacheServiceState::kBuildPhase && cookie_.empty())) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
"Not an error and we should bypass the build phase");
}
} else {
cinfo_.set_crc(tree_crc); // It's really a new cache we're creating so save our crc in the client
@ -243,10 +244,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(PushRequest(rq));
Status rc = rq->Wait();
bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey);
bool success = (rc.IsOk() || rc.StatusCode() == StatusCode::kMDDuplicateKey);
// If we get kDuplicateKey, it just means we aren't the first one to create the cache,
// and we will continue to parse the result.
if (rc.get_code() == StatusCode::kDuplicateKey) {
if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
RETURN_IF_NOT_OK(rq->PostReply());
}
if (success) {
@ -443,7 +444,7 @@ Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
}
// If the size is too big, tell the user to send it directly.
if (sz > kAsyncBufferSize) {
return Status(StatusCode::kNotImplementedYet);
return Status(StatusCode::kMDNotImplementedYet);
}
std::unique_lock<std::mutex> lock(mux_);
// Check error from the server side while we have the lock;

View File

@ -66,7 +66,7 @@ enum class CacheServiceState : int8_t {
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
reply->set_rc(static_cast<int32_t>(rc.get_code()));
reply->set_rc(static_cast<int32_t>(rc.StatusCode()));
reply->set_msg(rc.ToString());
}
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number

View File

@ -98,7 +98,7 @@ Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffer
(*out_fbb) = std::move(fbb);
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
}

View File

@ -95,7 +95,7 @@ Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
std::unique_lock<std::mutex> lck(mux_);
auto r = req_.emplace(seqNo, std::move(tag));
if (!r.second) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__);
}
}
// Last step is to tag the request.
@ -124,7 +124,7 @@ Status CacheClientGreeter::WorkerEntry() {
} else {
err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
}
Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg);
Status remote_rc = Status(StatusCode::kMDNetWorkError, __LINE__, __FILE__, err_msg);
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
}
// Notify the waiting thread.

View File

@ -25,7 +25,7 @@ Status PortToFtok(int port, SharedMemory::shm_key_t *out) {
shmkey = ftok(unix_path.data(), 'a');
if (shmkey == (key_t)-1) {
std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno);
return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg);
return Status(errno == ENOENT ? StatusCode::kMDFileNotExist : StatusCode::kMDUnexpectedError, errMsg);
}
*out = shmkey;
return Status::OK();
@ -56,7 +56,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
CacheMsgBuf msg{
1,
};
msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
msg.body.status.err_code = static_cast<int32_t>(rc.StatusCode());
auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size());
CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err));
msg.body.status.err_msg[rc.ToString().size()] = '\0';

View File

@ -25,16 +25,17 @@
#include <chrono>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset;
/// Start the server
/// \param argv
/// \return Status object
ds::Status StartServer(int argc, char **argv) {
ds::Status rc;
ms::Status StartServer(int argc, char **argv) {
ms::Status rc;
ds::CacheServer::Builder builder;
if (argc != 8) {
return ds::Status(ds::StatusCode::kSyntaxError);
return ms::Status(ms::StatusCode::kMDSyntaxError);
}
int32_t port = strtol(argv[3], nullptr, 10);
@ -53,7 +54,7 @@ ds::Status StartServer(int argc, char **argv) {
// is called. This is a standard procedure for daemonize a process on unix.
if (chdir("/") == -1) {
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
// A message queue for communication between parent and child (if we fork).
@ -80,13 +81,13 @@ ds::Status StartServer(int argc, char **argv) {
// failed to fork
if (pid < 0) {
std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else if (pid > 0) {
// Parent and will be responsible for remove the queue on exit.
msg.RemoveResourcesOnExit();
// Sleep one second and we attach to the msg que
std::this_thread::sleep_for(std::chrono::seconds(1));
ds::Status child_rc;
ms::Status child_rc;
rc = msg.ReceiveStatus(&child_rc);
if (rc.IsError()) {
return rc;
@ -101,7 +102,7 @@ ds::Status StartServer(int argc, char **argv) {
"logs (under "
<< ds::DefaultLogDir() << ") for any issues that may happen after startup\n";
signal(SIGCHLD, SIG_IGN); // ignore sig child signal.
return ds::Status::OK();
return ms::Status::OK();
} else {
// Child process will continue from here if daemonize and parent has already exited.
// If we are running in the foreground, none of the code in block below will be run.
@ -110,7 +111,7 @@ ds::Status StartServer(int argc, char **argv) {
sid = setsid();
if (sid < 0) {
std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
close(0);
close(1);
@ -137,10 +138,10 @@ ds::Status StartServer(int argc, char **argv) {
int main(int argc, char **argv) {
// This executable is not to be called directly, and should be invoked by cache_admin executable.
ds::Status rc = StartServer(argc, argv);
ms::Status rc = StartServer(argc, argv);
// Check result
if (rc.IsError()) {
auto errCode = rc.get_code();
auto errCode = rc.StatusCode();
auto errMsg = rc.ToString();
std::cerr << errMsg << std::endl;
return static_cast<int>(errCode);

View File

@ -136,7 +136,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
inx = (inx + 1) % num_slots;
} else {
return rc;
@ -162,7 +162,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
// Make the next arena and continue.
slot = (slot + 1) % num_segments;
} else {
@ -172,7 +172,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
}
// Handle the case we have done one round robin search.
if (ptr == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
return rc;
}

View File

@ -108,7 +108,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
bl.ptr = nullptr;
return rc;
}
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
// If no memory, write to disk.
if (sm_ != nullptr) {
MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes.";
@ -116,7 +116,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
} else {
// If asked to spill to disk instead but there is no storage set up, simply return no memory
// instead.
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "No enough storage for cache server to cache data");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "No enough storage for cache server to cache data");
}
} else {
return rc;
@ -125,7 +125,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
try {
rc = tree_->DoInsert(key, bl);
} catch (const std::bad_alloc &e) {
rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
rc = Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
// Duplicate key is treated as error and we will also free the memory.
if (rc.IsError() && bl.ptr != nullptr) {

View File

@ -223,7 +223,7 @@ Status CreateCacheRequest::Prepare() {
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
}
@ -277,7 +277,7 @@ Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
}

View File

@ -169,7 +169,7 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
max_avail -= mem_consumed;
if (max_avail <= 0) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
++it;
}
@ -179,12 +179,12 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
if (max_avail < avail_mem) {
int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit.
if (req_mem > max_avail) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} else if (req_mem == 0) {
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
// 85% of our limit, fail this request.
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
}
@ -249,7 +249,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
client_id = cs->num_clients_.fetch_add(1);
all_caches_.emplace(connection_id, std::move(cs));
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
}
@ -276,7 +276,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
// treat it as OK.
return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK();
return duplicate ? Status(StatusCode::kMDDuplicateKey) : Status::OK();
}
Status CacheServer::DestroyCache(CacheRequest *rq) {
@ -306,7 +306,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto sz = rq->buf_data_size();
std::vector<const void *> buffers;
@ -326,7 +326,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
reply->set_result(std::to_string(id));
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
return Status::OK();
@ -353,7 +353,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
Status rc;
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// Only if the cookie matches, we can accept insert into this cache that has a build phase
if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
@ -365,11 +365,11 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
} else {
auto state = cs->GetState();
if (state != CacheServiceState::kFetchPhase) {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Cache service is not in fetch phase. The current phase is " +
std::to_string(static_cast<int8_t>(state)) + ". Client id: " + std::to_string(client_id));
} else {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Cookie mismatch. Client id: " + std::to_string(client_id));
}
}
@ -413,7 +413,7 @@ Status CacheServer::InternalFetchRow(CacheRequest *rq) {
Status rc;
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data()));
// This is an internal request and is not tied to rpc. But need to post because there
@ -494,7 +494,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
auto &row_id_buf = rq->buf_data(0);
@ -551,7 +551,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
mem.resize(mem_sz);
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
WritableSlice dest(mem.data(), mem_sz);
RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
@ -568,7 +568,7 @@ Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CacheService::ServiceStat svc_stat;
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
@ -595,7 +595,7 @@ Status CacheServer::CacheSchema(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
auto &create_schema_buf = rq->buf_data(0);
@ -611,7 +611,7 @@ Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// We are going to use std::string to allocate and hold the result which will be eventually
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
@ -630,7 +630,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// First piece of data is the cookie
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
@ -639,7 +639,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
if (cookie == cs->cookie()) {
RETURN_IF_NOT_OK(cs->BuildPhaseDone());
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
return Status::OK();
@ -652,7 +652,7 @@ Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
std::vector<row_id_type> gap;
RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
@ -680,7 +680,7 @@ Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// First piece of data is the on/off flag
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
@ -747,7 +747,7 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto client_id = rq->client_id();
MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";
@ -836,7 +836,7 @@ Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *inter
default:
std::string errMsg("Internal error, request type is not row request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}
@ -860,7 +860,7 @@ Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
default:
std::string errMsg("Internal error, request type is not session request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}
@ -931,7 +931,7 @@ Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
default:
std::string errMsg("Internal error, request type is not admin request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}
@ -949,7 +949,7 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
} else {
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
// Notify it is done, and move on to the next request.
@ -1045,7 +1045,7 @@ Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
RETURN_UNEXPECTED_IF_NULL(q);
auto *p = new (std::nothrow) CacheServerRequest();
if (p == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
*q = p;
return Status::OK();
@ -1091,7 +1091,7 @@ Status CacheServer::DestroySession(CacheRequest *rq) {
} else {
std::string errMsg =
"Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + ".";
return Status(StatusCode::kFileNotExist, errMsg);
return Status(StatusCode::kMDFileNotExist, errMsg);
}
}
}
@ -1148,7 +1148,7 @@ Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto state = cs->GetState();
reply->set_result(std::to_string(static_cast<int8_t>(state)));
@ -1247,7 +1247,7 @@ Status CacheServer::Builder::IpcResourceCleanup() {
std::string errMsg = "Cache server is already up and running";
// We return a duplicate error. The main() will intercept
// and output a proper message
return Status(StatusCode::kDuplicateKey, errMsg);
return Status(StatusCode::kMDDuplicateKey, errMsg);
}
return Status::OK();
}

View File

@ -419,7 +419,7 @@ class CacheServer : public Service {
Status GetRc() {
Status rc;
for (auto &cache_rc : rc_lists_) {
if (cache_rc.IsError() && !cache_rc.IsInterrupted() && rc.IsOk()) {
if (cache_rc.IsError() && cache_rc != StatusCode::kMDInterrupted && rc.IsOk()) {
rc = cache_rc;
}
}

View File

@ -42,7 +42,7 @@ Status CacheService::DoServiceStart() {
// Return an error if we use more than recommended memory.
std::string errMsg = "Requesting cache size " + std::to_string(cache_mem_sz_) +
" while available system memory " + std::to_string(avail_mem);
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, errMsg);
}
memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem;
}
@ -79,7 +79,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
try {
// The first buffer is a flatbuffer which describes the rest of the buffers follow
@ -119,16 +119,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
}
// Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, all_data);
if (rc == Status(StatusCode::kDuplicateKey)) {
if (rc == Status(StatusCode::kMDDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
if (rc == StatusCode::kMDNoSpace) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
st_ = CacheServiceState::kOutOfMemory;
}
}
@ -152,7 +152,7 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
try {
// If we don't need to generate id, we need to find it from the buffer.
@ -172,16 +172,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
}
// Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, {src});
if (rc == Status(StatusCode::kDuplicateKey)) {
if (rc == Status(StatusCode::kMDDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
if (rc == StatusCode::kMDNoSpace) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
st_ = CacheServiceState::kOutOfMemory;
}
}
@ -307,7 +307,7 @@ Status CacheService::FetchSchema(std::string *out) const {
if (!mem.empty()) {
*out = std::move(mem);
} else {
return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
return Status(StatusCode::kMDFileNotExist, __LINE__, __FILE__, "No schema has been cached");
}
return Status::OK();
}

View File

@ -36,7 +36,7 @@ Status CachePerfMsg::Receive(int32_t qID) {
auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR);
if (err == -1) {
if (errno == EIDRM) {
return Status(StatusCode::kInterrupted);
return Status(StatusCode::kMDInterrupted);
} else {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);

View File

@ -33,7 +33,7 @@ int main(int argc, char **argv) {
if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl;
}
return static_cast<int>(rc.get_code());
return static_cast<int>(rc.StatusCode());
}
return 0;
}

View File

@ -100,5 +100,7 @@ class CachePerfRun {
};
} // namespace dataset
} // namespace mindspore
// todo: waiting for the master of the codes to refactor
#define get_code StatusCode
#define kDuplicateKey kMDDuplicateKey
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_

View File

@ -33,12 +33,12 @@ int main(int argc, char **argv) {
// If we hit any error, send the rc back to the parent.
if (rc.IsError()) {
ds::ErrorMsg proto;
proto.set_rc(static_cast<int32_t>(rc.get_code()));
proto.set_rc(static_cast<int32_t>(rc.StatusCode()));
proto.set_msg(rc.ToString());
ds::CachePerfMsg msg;
(void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto);
}
return static_cast<int>(rc.get_code());
return static_cast<int>(rc.StatusCode());
}
return 0;
}

View File

@ -115,5 +115,9 @@ class CachePipelineRun {
};
} // namespace dataset
} // namespace mindspore
// todo: waiting for the master of the codes to refactor
#define get_code StatusCode
#define kDuplicateKey kMDDuplicateKey
#define IsOutofMemory() StatusCode() == StatusCode::kMDOutOfMemory
#define IsNoSpace() StatusCode() == StatusCode::kMDNoSpace
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_

View File

@ -104,7 +104,7 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const
if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno;
if (errno == ENOSPC) {
return Status(StatusCode::kNoSpace, __LINE__, __FILE__);
return Status(StatusCode::kMDNoSpace, __LINE__, __FILE__);
} else {
RETURN_STATUS_UNEXPECTED(strerror(err));
}
@ -157,7 +157,7 @@ Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer
Status rc;
auto sc = new (std::nothrow) StorageContainer(path);
if (sc == nullptr) {
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
rc = sc->Create();
if (rc.IsOk()) {

View File

@ -96,9 +96,9 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu
cont = containers_.at(num_containers - 1);
off64_t offset;
Status rc = cont->Insert(buf, &offset);
if (rc.get_code() == StatusCode::kBuddySpaceFull) {
if (rc.StatusCode() == StatusCode::kMDBuddySpaceFull) {
create_new_container = true;
// Remember how many containers we saw. In the next iteration we will do a comparision to see
// Remember how many containers we saw. In the next iteration we will do a comparison to see
// if someone has already created it.
last_num_container = num_containers;
} else if (rc.IsOk()) {

View File

@ -140,7 +140,7 @@ Status ColDescriptor::MaterializeTensorShape(int32_t num_elements, TensorShape *
// If we already had an unknown dimension, then we cannot have a second unknown dimension.
// We only support the compute of a single unknown dim.
if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Requested shape has more than one unknown dimension!");
}
@ -312,12 +312,12 @@ Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::strin
}
// data type is mandatory field
if (type_str.empty())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"json schema file for column " + col_name + " has invalid or missing column type.");
// rank number is mandatory field
if (rank_value <= -1)
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"json schema file for column " + col_name + " must define a positive rank value.");
// Create the column descriptor for this column from the data we pulled from the json file
@ -425,7 +425,7 @@ Status DataSchema::AddColumn(const ColDescriptor &cd) {
Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
// Check if columns node exists. It is required for building schema from file.
if (js.find("columns") == js.end())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"\"columns\" node is required in the schema json file.");
return Status::OK();
}
@ -434,12 +434,12 @@ Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
// name to column index number.
Status DataSchema::GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map) {
if (out_column_name_map == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map.");
}
for (int32_t i = 0; i < col_descs_.size(); ++i) {
if (col_descs_[i].name().empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Constructing column name map from schema, but found empty column name.");
}
(*out_column_name_map)[col_descs_[i].name()] = i;

View File

@ -290,7 +290,7 @@ Status ChildIterator::Drain() {
RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
}
if (curr_buffer_->eof()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
}
return Status::OK();
}

View File

@ -122,7 +122,8 @@ Status BarrierOp::prepare(TensorQTable *const table) {
clean_up_ = false;
buffer_id_ = 0;
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"BarrierOp prepare phase requires a tensor table.");
}
// fill initial row
TensorRow new_row = {};
@ -150,7 +151,7 @@ Status BarrierOp::prepare(TensorQTable *const table) {
// fillBuffer always expects a new table to fill
Status BarrierOp::fillBuffer(TensorQTable *const table) {
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
}
TensorRow new_row = {};
while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
@ -172,7 +173,7 @@ Status BarrierOp::blockCond() {
{
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
}
// we have condition name, however the flexibility is in python today
try {
@ -180,11 +181,11 @@ Status BarrierOp::blockCond() {
py::object ret_py_obj = condition_function_();
// Process the return value
if (!py::isinstance<py::bool_>(ret_py_obj)) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, condition wait function should return true/false.");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
return Status(StatusCode::kMDPyFuncException, e.what());
}
}
return Status::OK();

View File

@ -61,7 +61,7 @@ Status BatchOp::Builder::SanityCheck() {
err += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
std::to_string(builder_num_workers_) + ".\n"
: "";
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
return err.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}
#ifdef ENABLE_PYTHON
@ -261,7 +261,7 @@ Status BatchOp::MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatc
Status BatchOp::LaunchThreadsAndInitOp() {
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
@ -338,23 +338,23 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized.");
}
try {
py::object size = batch_size_func_(info);
*batch_size = size.cast<int32_t>();
if (*batch_size <= 0) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, batch size function should return an integer greater than 0.");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
return Status(StatusCode::kMDPyFuncException, e.what());
} catch (const py::cast_error &e) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, batch size function should return an integer greater than 0.");
}
}
return Status(StatusCode::kOK, "Batch size func call succeed.");
return Status(StatusCode::kSuccess, "Batch size func call succeed.");
}
Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
@ -362,7 +362,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized.");
}
try {
// Prepare batch map call back parameters
@ -407,9 +407,9 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
output->push_back(std::move(output_batch));
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
return Status(StatusCode::kMDPyFuncException, e.what());
} catch (const py::cast_error &e) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, batch map function should return a tuple of list of numpy array.");
}
}

View File

@ -191,7 +191,7 @@ Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t ba
if (bucket_index + 1 >= bucket_boundaries_.size()) {
std::string error_message =
"Invalid data, requested to pad to bucket boundary, element falls in last bucket.";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, error_message);
}
pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1;

View File

@ -42,7 +42,7 @@ BuildSentencePieceVocabOp::BuildSentencePieceVocabOp(std::shared_ptr<SentencePie
Status BuildSentencePieceVocabOp::operator()() {
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(sentence_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
@ -84,10 +84,10 @@ Status BuildSentencePieceVocabOp::SentenceThread() {
sentencepiece::util::Status s_status =
sentencepiece::SentencePieceTrainer::Train(BuildParams(), sentence_iter.get(), &model_proto);
if (!s_status.ok()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, s_status.message());
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, s_status.message());
} else {
if (vocab_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, sentencepiece vocab not set.");
}
vocab_->set_model_proto(model_proto);
@ -145,7 +145,7 @@ void BuildSentencePieceVocabOp::Next(std::string *sentence) {
if (new_row[col_id_]->type().IsNumeric() || new_row[col_id_]->Rank() > 1) {
ret_status_ =
Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid data, build_sentence_piece_vocab only works on string data with rank equal to 1, got type: " +
new_row[col_id_]->type().ToString() + "and rank: " + std::to_string(new_row[col_id_]->Rank()));
read_done_ = true;

View File

@ -80,7 +80,7 @@ Status BuildVocabOp::WorkerEntry(int32_t worker_id) {
Status BuildVocabOp::operator()() {
// launch the collector thread
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks()));

View File

@ -233,7 +233,7 @@ Status CacheBase::UpdateColumnMapFromCache() {
// Get the schema from the server. It may not be there yet. So tolerate the error.
if (column_name_id_map_.empty()) {
rc = cache_client_->FetchSchema(&column_name_id_map_);
if (rc == Status(StatusCode::kFileNotExist)) {
if (rc == Status(StatusCode::kMDFileNotExist)) {
MS_LOG(DEBUG) << "Schema not in the server yet.";
rc = Status::OK();
}
@ -304,14 +304,14 @@ Status CacheBase::Prefetcher(int32_t worker_id) {
int32_t retry_count = 0;
do {
rc = PrefetchRows(prefetch_keys, &cache_miss);
if (rc.IsNetWorkError() && retry_count < max_retries) {
if (rc == StatusCode::kMDNetWorkError && retry_count < max_retries) {
// If we get some network error, we will attempt some retries
retry_count++;
} else if (rc.IsError()) {
MS_LOG(WARNING) << rc.ToString();
return rc;
}
} while (rc.IsNetWorkError());
} while (rc == StatusCode::kMDNetWorkError);
// In case any thread is waiting for the rows to come back and blocked on a semaphore,
// we will put an empty row in the local cache.
if (rc.IsError() && AllowCacheMiss()) {

View File

@ -39,12 +39,12 @@ CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_
// Check if the required parameters are set by the builder.
Status CacheLookupOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheLookupOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheLookupOp requires a session id which is not equal to 0.");
}
return Status::OK();
@ -59,7 +59,7 @@ Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
}
Status CacheLookupOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheLookupOp requires a sampler before it can be executed, but got nullptr.");
}
RETURN_IF_NOT_OK(RegisterResources());

View File

@ -129,7 +129,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
Status rc;
if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) {
cache_missing_rows_ = false;
if (rc.IsOutofMemory() || rc.IsNoSpace()) {
if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
cache_client_->ServerRunningOutOfResources();
} else {
MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString();
@ -156,7 +156,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
} else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
cache_missing_rows_ = false;
cache_client_->ServerRunningOutOfResources();
}
@ -188,9 +188,9 @@ Status CacheMergeOp::Cleaner() {
Status rc = rq->CheckCacheResult();
if (rc.IsError()) {
// If interrupt, time to quit.
if (rc.IsInterrupted()) {
if (rc == StatusCode::kMDInterrupted) {
return Status::OK();
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
} else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
// The server is hitting some limit and we will turn off caching from now on.
cache_missing_rows_ = false;
cache_client_->ServerRunningOutOfResources();
@ -215,7 +215,7 @@ Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from supe
// Construct the cache
const bool generate_ids = false;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
// We are told the cache has been created already.
MS_LOG(INFO) << "Cache created already";
rc = Status::OK();
@ -244,12 +244,12 @@ CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(
// Check if the required parameters are set by the builder.
Status CacheMergeOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheMergeOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheMergeOp requires a session id which is not equal to 0.");
}
return Status::OK();
@ -316,7 +316,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
rc = cc->AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) {
if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {

View File

@ -41,12 +41,12 @@ CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullp
// Check if the required parameters are set by the builder.
Status CacheOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheOp requires a session id which is not equal to 0.");
}
return Status::OK();
@ -78,7 +78,7 @@ Status CacheOp::InitCache() { return Status::OK(); }
// This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
}
RETURN_IF_NOT_OK(RegisterResources());
@ -113,7 +113,7 @@ Status CacheOp::CacheAllRows(int32_t worker_id) {
Status rc;
// Do the Async write if we attach to the shared memory.
rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr));
if (rc.get_code() == StatusCode::kNotImplementedYet) {
if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
} else if (rc.IsError()) {
return rc;
@ -169,9 +169,9 @@ Status CacheOp::WaitForCachingAllRows() {
BuildPhaseDone = true;
break;
case CacheServiceState::kOutOfMemory:
return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory");
return Status(StatusCode::kMDOutOfMemory, "Cache server is running out of memory");
case CacheServiceState::kNoSpace:
return Status(StatusCode::kNoSpace, "Cache server is running of out spill storage");
return Status(StatusCode::kMDNoSpace, "Cache server is running of out spill storage");
case CacheServiceState::kNone:
case CacheServiceState::kError:
default:
@ -246,7 +246,7 @@ Status CacheOp::PrepareNodePostAction() {
// Construct the cache
const bool generate_ids = true;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
// We are told the cache has been created already. So we skip the build phase.
phase_ = Phase::kFetchPhase;
rc = Status::OK();

View File

@ -157,18 +157,14 @@ Status DeviceQueueOp::SendDataToAscend() {
TensorRow currRow;
for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
while (stop_send_ && ascend_keep_waiting_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
WaitContinueSignal();
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
if (status == TdtStatus::FAILED) {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
return Status(StatusCode::kMDTDTPushFailure, "TDT Push Failed");
}
if (create_data_info_queue_) {
DATA_INFO data_info;
@ -200,9 +196,8 @@ Status DeviceQueueOp::SendDataToAscend() {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
return Status(StatusCode::kMDTDTPushFailure, "TDT Push Failed");
}
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
stop_send_ = true;
@ -219,13 +214,19 @@ Status DeviceQueueOp::SendDataToAscend() {
return Status::OK();
}
void DeviceQueueOp::WaitContinueSignal() const {
while (stop_send_ && ascend_keep_waiting_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
}
#endif
#ifdef ENABLE_TDTQUE
Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
if (!create_data_info_queue_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
}
// This place has a race condition with operator(), so the first one
// arrive here will do the initialize work.
@ -241,7 +242,7 @@ Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
}
#else
Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
}
#endif
@ -301,7 +302,7 @@ Status DeviceQueueOp::PushDataToGPU() {
}
handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, release_function);
if (handle == INVALID_HANDLE) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to open channel for sending data.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Failed to open channel for sending data.");
}
is_open = true;
}
@ -309,14 +310,14 @@ Status DeviceQueueOp::PushDataToGPU() {
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
}
}
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
if (ret) {
if (ret == BlockQueueStatus_T::ERROR_INPUT) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Invalid input data, please check it.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Invalid input data, please check it.");
} else {
if (!stop_send_) {
MS_LOG(DEBUG) << "Retry pushing data...";
@ -438,13 +439,13 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
for (auto &sub_item : *items) {
RETURN_IF_NOT_OK(pool_[worker_id]->Allocate(sub_item.data_len_, &sub_item.data_ptr_));
if (sub_item.data_ptr_ == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
}
const unsigned char *column_data = curr_row[i]->GetBuffer();
if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data,
static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) {
MS_LOG(ERROR) << "memcpy_s failed!";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memcpy_s failed.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "memcpy_s failed.");
}
}

View File

@ -190,6 +190,7 @@ class DeviceQueueOp : public PipelineOp {
private:
#ifdef ENABLE_TDTQUE
void WaitContinueSignal() const;
Status SendDataToAscend();
bool ascend_keep_waiting_;
#endif

View File

@ -43,7 +43,7 @@ Status FilterOp::Builder::SanityCheck() {
err += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
std::to_string(builder_num_workers_) + ".\n"
: "";
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
return err.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}
FilterOp::Builder::Builder() {
@ -66,7 +66,7 @@ FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_wor
Status FilterOp::operator()() {
// The operator class just starts off threads by calling the tree_ function.
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
filter_queues_.Init(num_workers_, oc_queue_size_);
RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
@ -244,7 +244,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
RETURN_IF_NOT_OK(predicate_func_->Compute(input, &output));
RETURN_IF_NOT_OK(output.at(0)->GetItemAt(out_predicate, {}));
return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
}
// Visitor accept method for NodePass

Some files were not shown because too many files have changed in this diff Show More