cxx api refactor: tensor/status/model

lixian 2021-01-26 17:06:34 +08:00 committed by zhoufeng
parent f1009cb21b
commit 7d2fd6e76c
227 changed files with 4285 additions and 2786 deletions

View File

@@ -65,7 +65,7 @@ install(
 install(
     TARGETS mindspore_shared_lib
-    LIBRARY DESTINATION ${INSTALL_LIB_DIR}
+    DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
 )
@@ -327,7 +327,7 @@ install(
     ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/transforms.h
     ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision.h
     ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision_lite.h
-    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/minddata_eager.h
+    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/execute.h
     DESTINATION ${INSTALL_BASE_DIR}/include/minddata/dataset/include
     COMPONENT mindspore
 )

View File

@@ -109,6 +109,8 @@ if(PLATFORM_ARM64)
             COMPONENT ${RUNTIME_COMPONENT_NAME})
     install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
             COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend* ops*" EXCLUDE)
     if(ENABLE_TOOLS)
         install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
     endif()
@@ -128,6 +130,8 @@ elseif(PLATFORM_ARM32)
             COMPONENT ${RUNTIME_COMPONENT_NAME})
     install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
             COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
     if(ENABLE_TOOLS)
         install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
     endif()
@@ -162,6 +166,8 @@ elseif(WIN32)
     endif()
     install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
             COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
     set(WIN_LIB_DIR_RUN_X86 ${RUNTIME_PKG_NAME}/benchmark)
     install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.a DESTINATION ${WIN_LIB_DIR_RUN_X86}
             COMPONENT ${RUNTIME_COMPONENT_NAME})
@@ -182,6 +188,8 @@ else()
     endif()
     install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
             COMPONENT ${RUNTIME_COMPONENT_NAME})
+    install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
     install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${RUNTIME_LIB_DIR}
             COMPONENT ${RUNTIME_COMPONENT_NAME})
     install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}

View File

@@ -24,7 +24,6 @@
 #include "include/api/graph.h"
 namespace mindspore {
-namespace api {
 class InputAndOutput;
 using Input = InputAndOutput;
 using Output = InputAndOutput;
@@ -35,7 +34,7 @@ class MS_API CellBase {
   virtual ~CellBase() = default;
   virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
   virtual std::shared_ptr<CellBase> Clone() const = 0;
-  virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { return SUCCESS; }
+  virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; }
   std::vector<Output> operator()(const std::vector<Input> &inputs) const;
 };
@@ -57,16 +56,16 @@ class MS_API ParameterCell final : public Cell<ParameterCell> {
   ParameterCell(ParameterCell &&);
   ParameterCell &operator=(ParameterCell &&);
-  explicit ParameterCell(const Tensor &);
-  ParameterCell &operator=(const Tensor &);
-  explicit ParameterCell(Tensor &&);
-  ParameterCell &operator=(Tensor &&);
-  Tensor GetTensor() const { return tensor_; }
+  explicit ParameterCell(const MSTensor &);
+  ParameterCell &operator=(const MSTensor &);
+  explicit ParameterCell(MSTensor &&);
+  ParameterCell &operator=(MSTensor &&);
+  MSTensor GetTensor() const { return tensor_; }
  private:
-  Tensor tensor_;
+  MSTensor tensor_;
 };
 class MS_API OpCellBase : public CellBase {
@@ -99,11 +98,9 @@ class MS_API GraphCell final : public Cell<GraphCell> {
   explicit GraphCell(const std::shared_ptr<Graph> &);
   const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
-  Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
-  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
-  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
+  Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
+  std::vector<MSTensor> GetInputs();
+  std::vector<MSTensor> GetOutputs();
  private:
   friend class ModelImpl;
@@ -119,8 +116,8 @@ class MS_API InputAndOutput {
   ~InputAndOutput() = default;
   // no explicit
-  InputAndOutput(const Tensor &);  // NOLINT(runtime/explicit)
-  InputAndOutput(Tensor &&);       // NOLINT(runtime/explicit)
+  InputAndOutput(const MSTensor &);  // NOLINT(runtime/explicit)
+  InputAndOutput(MSTensor &&);       // NOLINT(runtime/explicit)
   InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
@@ -132,6 +129,5 @@ class MS_API InputAndOutput {
   std::vector<InputAndOutput> prev_;
   int32_t index_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CELL_H
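A minimal usage sketch of the refactored GraphCell interface above, assuming the post-refactor headers are on the include path; the model file name is a placeholder and input data is left unfilled:

#include <vector>
#include "include/api/cell.h"
#include "include/api/serialization.h"
#include "include/api/types.h"

// Hypothetical sketch: load a MindIR graph, wrap it in a GraphCell, run it with MSTensor I/O.
int RunGraphCell() {
  mindspore::Graph graph = mindspore::Serialization::LoadModel("net.mindir", mindspore::kMindIR);
  mindspore::GraphCell cell(graph);
  std::vector<mindspore::MSTensor> inputs = cell.GetInputs();  // replaces the old GetInputsInfo()
  std::vector<mindspore::MSTensor> outputs;
  mindspore::Status ret = cell.Run(inputs, &outputs);          // MSTensor in, MSTensor out
  return ret == mindspore::kSuccess ? 0 : -1;
}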

View File

@@ -16,26 +16,49 @@
 #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
 #define MINDSPORE_INCLUDE_API_CONTEXT_H
+#include <map>
+#include <any>
 #include <string>
 #include <memory>
 #include "include/api/types.h"
 namespace mindspore {
-namespace api {
-class MS_API Context {
- public:
-  static Context &Instance();
-  const std::string &GetDeviceTarget() const;
-  Context &SetDeviceTarget(const std::string &device_target);
-  uint32_t GetDeviceID() const;
-  Context &SetDeviceID(uint32_t device_id);
- private:
-  Context();
-  ~Context();
-  class ContextImpl;
-  std::shared_ptr<ContextImpl> impl_;
-};
-}  // namespace api
+constexpr auto kDeviceTypeAscend310 = "Ascend310";
+constexpr auto kDeviceTypeAscend910 = "Ascend910";
+struct MS_API Context {
+  virtual ~Context() = default;
+  std::map<std::string, std::any> params;
+};
+struct MS_API GlobalContext : public Context {
+  static std::shared_ptr<Context> GetGlobalContext();
+  static void SetGlobalDeviceTarget(const std::string &device_target);
+  static std::string GetGlobalDeviceTarget();
+  static void SetGlobalDeviceID(const uint32_t &device_id);
+  static uint32_t GetGlobalDeviceID();
+};
+struct MS_API ModelContext : public Context {
+  static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
+  static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
+  static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
+  static std::string GetInputFormat(const std::shared_ptr<Context> &context);
+  static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
+  static std::string GetInputShape(const std::shared_ptr<Context> &context);
+  static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
+  static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
+  static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
+  static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
+  static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
+  static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
+};
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
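A sketch of how the new context API is meant to be driven, based only on the declarations above; the concrete option values ("NCHW", the precision mode, device id 0) are illustrative:

#include <memory>
#include "include/api/context.h"

// Illustrative sketch: global device settings via GlobalContext, per-model options via ModelContext.
void ConfigureContexts() {
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  mindspore::GlobalContext::SetGlobalDeviceID(0);

  auto model_context = std::make_shared<mindspore::Context>();
  mindspore::ModelContext::SetInputFormat(model_context, "NCHW");
  mindspore::ModelContext::SetOutputType(model_context, mindspore::DataType::kNumberTypeFloat32);
  mindspore::ModelContext::SetPrecisionMode(model_context, "allow_fp32_to_fp16");
}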

include/api/data_type.h (new file, 43 lines)
View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_
#define MINDSPORE_INCLUDE_API_DATA_TYPE_H_
namespace mindspore {
enum class DataType : int {
kTypeUnknown = 0,
kObjectTypeString = 12,
kObjectTypeList = 13,
kObjectTypeTuple = 14,
kObjectTypeTensorType = 17,
kNumberTypeBool = 30,
kNumberTypeInt8 = 32,
kNumberTypeInt16 = 33,
kNumberTypeInt32 = 34,
kNumberTypeInt64 = 35,
kNumberTypeUInt8 = 37,
kNumberTypeUInt16 = 38,
kNumberTypeUInt32 = 39,
kNumberTypeUInt64 = 40,
kNumberTypeFloat16 = 42,
kNumberTypeFloat32 = 43,
kNumberTypeFloat64 = 44,
kNumberTypeEnd = 46,
// add new enum here
kInvalidType = INT32_MAX,
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DATA_TYPE_H_

View File

@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_INCLUDE_API_GRAPH_H
 #define MINDSPORE_INCLUDE_API_GRAPH_H
+#include <cstddef>
 #include <string>
 #include <vector>
 #include <map>
@@ -24,21 +25,21 @@
 #include "include/api/types.h"
 namespace mindspore {
-namespace api {
 class MS_API Graph {
  public:
   class GraphData;
   explicit Graph(const std::shared_ptr<GraphData> &graph_data);
   explicit Graph(std::shared_ptr<GraphData> &&graph_data);
+  explicit Graph(std::nullptr_t);
   ~Graph();
   enum ModelType ModelType() const;
+  bool operator==(std::nullptr_t) const;
  private:
   friend class GraphCell;
   friend class ModelImpl;
   std::shared_ptr<GraphData> graph_data_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_GRAPH_H

View File

@@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
#define MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
#include <string>
#include <memory>
#include <map>
#include <any>
#include "include/api/types.h"
namespace mindspore {
namespace lite {
/// \brief CpuBindMode defined for holding bind cpu strategy argument.
typedef enum : uint32_t {
NO_BIND = 0, /**< no bind */
HIGHER_CPU = 1, /**< bind higher cpu first */
MID_CPU = 2 /**< bind middle cpu first */
} CpuBindMode;
class Allocator;
} // namespace lite
struct MS_API Context {
public:
static void Clear(const std::shared_ptr<Context> &contxet);
static void SetAsDefault(const std::shared_ptr<Context> &contxet);
static void SetVendorName(const std::shared_ptr<Context> &contxet, const std::string &name);
static std::string GetVendorName(const std::shared_ptr<Context> &contxet);
static void SetThreadNum(const std::shared_ptr<Context> &contxet, int num);
static int GetThreadNum(const std::shared_ptr<Context> &contxet);
static void SetAllocator(const std::shared_ptr<Context> &contxet, std::shared_ptr<lite::Allocator> alloc);
static std::shared_ptr<lite::Allocator> GetAllocator(const std::shared_ptr<Context> &contxet);
static void ConfigCPU(const std::shared_ptr<Context> &contxet, bool config);
static bool IfCPUEnabled(const std::shared_ptr<Context> &contxet);
static void ConfigCPUFp16(const std::shared_ptr<Context> &contxet, bool config);
static bool IfCPUFp16Enabled(const std::shared_ptr<Context> &contxet);
static void SetCPUBindMode(const std::shared_ptr<Context> &contxet, lite::CpuBindMode mode);
static lite::CpuBindMode GetCPUBindMode(const std::shared_ptr<Context> &contxet);
static void ConfigGPU(const std::shared_ptr<Context> &contxet, bool config);
static bool IfGPUEnabled(const std::shared_ptr<Context> &contxet);
static void ConfigGPUFp16(const std::shared_ptr<Context> &contxet, bool config);
static bool IfGPUFp16Enabled(const std::shared_ptr<Context> &contxet);
static void ConfigNPU(const std::shared_ptr<Context> &contxet, bool config);
static bool IfNPUEnabled(const std::shared_ptr<Context> &contxet);
static void SetNPUFrequency(const std::shared_ptr<Context> &contxet, int freq);
static int GetNPUFrequency(const std::shared_ptr<Context> &contxet);
private:
std::map<std::string, std::any> context_;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
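For the Lite flavour of Context declared above, configuration goes through the same static-setter pattern. A rough sketch under the assumption that the header installs as include/api/lite_context.h (inferred from its include guard); the thread count and bind mode are arbitrary:

#include <memory>
#include "include/api/lite_context.h"

// Rough sketch: enable the CPU backend with two threads bound to higher-frequency cores.
void ConfigureLiteContext() {
  auto context = std::make_shared<mindspore::Context>();
  mindspore::Context::SetThreadNum(context, 2);
  mindspore::Context::ConfigCPU(context, true);
  mindspore::Context::SetCPUBindMode(context, mindspore::lite::HIGHER_CPU);
}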

View File

@@ -20,41 +20,36 @@
 #include <vector>
 #include <map>
 #include <memory>
+#include <utility>
 #include "include/api/status.h"
 #include "include/api/types.h"
 #include "include/api/graph.h"
 #include "include/api/cell.h"
 namespace mindspore {
-namespace api {
 class ModelImpl;
-// todo: minddata c++ interface
-class DataSet {};
+struct Context;
 class MS_API Model {
  public:
-  explicit Model(const std::vector<Output> &network);
-  explicit Model(const GraphCell &graph);
+  explicit Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context = nullptr);
+  explicit Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context = nullptr);
   ~Model();
   Model(const Model &) = delete;
   void operator=(const Model &) = delete;
-  Status Build(const std::map<std::string, std::string> &options);
-  Status Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
-  Status Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
-  Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
-  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
-  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
+  Status Build();
+  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
+  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
+  std::vector<MSTensor> GetInputs();
+  std::vector<MSTensor> GetOutputs();
   static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
  private:
   std::shared_ptr<ModelImpl> impl_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_MODEL_H
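The Build/Predict flow implied by the new Model interface, sketched end to end under the post-refactor API; the model file name and device choice are placeholders, and real input data would have to be written into the tensors before Predict:

#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"

// Placeholder sketch: options now travel in a Context instead of a string map passed to Build().
int RunModel() {
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  mindspore::GlobalContext::SetGlobalDeviceID(0);

  mindspore::Graph graph = mindspore::Serialization::LoadModel("net.mindir", mindspore::kMindIR);
  auto model_context = std::make_shared<mindspore::Context>();
  mindspore::Model model(mindspore::GraphCell(graph), model_context);
  if (model.Build() != mindspore::kSuccess) {
    return -1;
  }
  std::vector<mindspore::MSTensor> inputs = model.GetInputs();  // fill with real data before Predict
  std::vector<mindspore::MSTensor> outputs;
  return model.Predict(inputs, &outputs) == mindspore::kSuccess ? 0 : -1;
}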

View File

@@ -25,7 +25,6 @@
 #include "include/api/cell.h"
 namespace mindspore {
-namespace api {
 struct MS_API Conv2D : public OpCell<Conv2D> {
   Conv2D() : OpCell("Conv2D") {}
   ~Conv2D() override = default;
@@ -45,6 +44,5 @@ struct MS_API Conv2D : public OpCell<Conv2D> {
   std::vector<int> dilation = {1, 1, 1, 1};
   int group = 1;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_OPS_OPS_H

View File

@@ -26,15 +26,14 @@
 #include "include/api/graph.h"
 namespace mindspore {
-namespace api {
 class MS_API Serialization {
  public:
+  static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
   static Graph LoadModel(const std::string &file, ModelType model_type);
   static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
   static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
   static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
   static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H
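The newly added overload lets a caller load a model image that is already in memory rather than on disk. A hedged sketch of that path; how the buffer is obtained here (reading a whole file) is just one possibility:

#include <fstream>
#include <iterator>
#include <string>
#include <vector>
#include "include/api/serialization.h"

// Hypothetical sketch: read a MindIR file into memory and load it via the (data, size) overload.
mindspore::Graph LoadFromMemory(const std::string &path) {
  std::ifstream ifs(path, std::ios::binary);
  std::vector<char> buf((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
  return mindspore::Serialization::LoadModel(buf.data(), buf.size(), mindspore::kMindIR);
}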

View File

@@ -17,37 +17,129 @@
 #define MINDSPORE_INCLUDE_API_STATUS_H
 #include <string>
+#include <ostream>
+#include <climits>
 namespace mindspore {
-namespace api {
-enum StatusCode {
-  SUCCESS = 0,
-  FAILED,
-  INVALID_INPUTS,
-  // insert new status code here
-  UNKNOWN = 0xFFFFFFFF
+enum CompCode : uint32_t {
+  kCore = 0x00000000u,
+  kMD = 0x10000000u,
+  kME = 0x20000000u,
+  kMC = 0x30000000u,
+  kLite = 0xF0000000u,
+};
+enum StatusCode : uint32_t {
+  kSuccess = 0,
+  // Core
+  kCoreFailed = kCore | 0x1,
+  // MD
+  kMDOutOfMemory = kMD | 1,
+  kMDShapeMisMatch = kMD | 2,
+  kMDInterrupted = kMD | 3,
+  kMDNoSpace = kMD | 4,
+  kMDPyFuncException = kMD | 5,
+  kMDDuplicateKey = kMD | 6,
+  kMDPythonInterpreterFailure = kMD | 7,
+  kMDTDTPushFailure = kMD | 8,
+  kMDFileNotExist = kMD | 9,
+  kMDProfilingError = kMD | 10,
+  kMDBoundingBoxOutOfBounds = kMD | 11,
+  kMDBoundingBoxInvalidShape = kMD | 12,
+  kMDSyntaxError = kMD | 13,
+  kMDTimeOut = kMD | 14,
+  kMDBuddySpaceFull = kMD | 15,
+  kMDNetWorkError = kMD | 16,
+  kMDNotImplementedYet = kMD | 17,
+  // Make this error code the last one. Add new error code above it.
+  kMDUnexpectedError = kMD | 127,
+  // ME
+  kMEFailed = kME | 0x1,
+  kMEInvalidInput = kME | 0x2,
+  // MC
+  kMCFailed = kMC | 0x1,
+  kMCDeviceError = kMC | 0x2,
+  kMCInvalidInput = kMC | 0x3,
+  kMCInvalidArgs = kMC | 0x4,
+  // Lite  // Common error code, range: [-1, -100
+  kLiteError = kLite | (0x0FFFFFFF & -1),               /**< Common error code. */
+  kLiteNullptr = kLite | (0x0FFFFFFF & -2),             /**< NULL pointer returned.*/
+  kLiteParamInvalid = kLite | (0x0FFFFFFF & -3),        /**< Invalid parameter.*/
+  kLiteNoChange = kLite | (0x0FFFFFFF & -4),            /**< No change. */
+  kLiteSuccessExit = kLite | (0x0FFFFFFF & -5),         /**< No error but exit. */
+  kLiteMemoryFailed = kLite | (0x0FFFFFFF & -6),        /**< Fail to create memory. */
+  kLiteNotSupport = kLite | (0x0FFFFFFF & -7),          /**< Fail to support. */
+  kLiteThreadPoolError = kLite | (0x0FFFFFFF & -8),     /**< Error occur in thread pool. */
+  // Executor error code, range: [-100,-200)
+  kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100),  /**< Failed to check range. */
+  kLiteInputTensorError = kLite | (0x0FFFFFFF & -101),  /**< Failed to check input tensor. */
+  kLiteReentrantError = kLite | (0x0FFFFFFF & -102),    /**< Exist executor running. */
+  // Graph error code, range: [-200,-300)
+  kLiteGraphFileError = kLite | (0x0FFFFFFF & -200),    /**< Failed to verify graph file. */
+  // Node error code, range: [-300,-400)
+  kLiteNotFindOp = kLite | (0x0FFFFFFF & -300),         /**< Failed to find operator. */
+  kLiteInvalidOpName = kLite | (0x0FFFFFFF & -301),     /**< Invalid operator name. */
+  kLiteInvalidOpAttr = kLite | (0x0FFFFFFF & -302),     /**< Invalid operator attr. */
+  kLiteOpExecuteFailure = kLite | (0x0FFFFFFF & -303),  /**< Failed to execution operator. */
+  // Tensor error code, range: [-400,-500)
+  kLiteFormatError = kLite | (0x0FFFFFFF & -400),       /**< Failed to checking tensor format. */
+  // InferShape error code, range: [-500,-600)
+  kLiteInferError = kLite | (0x0FFFFFFF & -500),        /**< Failed to infer shape. */
+  kLiteInferInvalid = kLite | (0x0FFFFFFF & -501),      /**< Invalid infer shape before runtime. */
+  // User input param error code, range: [-600, 700)
+  kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
 };
 class Status {
  public:
-  Status() : status_code_(FAILED) {}
+  Status() : status_code_(kSuccess), line_of_code_(-1) {}
   Status(enum StatusCode status_code, const std::string &status_msg = "")  // NOLINT(runtime/explicit)
-      : status_code_(status_code), status_msg_(status_msg) {}
+      : status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {}
+  Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
   ~Status() = default;
-  bool IsSuccess() const { return status_code_ == SUCCESS; }
   enum StatusCode StatusCode() const { return status_code_; }
-  std::string StatusMessage() const { return status_msg_; }
+  const std::string &ToString() const { return status_msg_; }
+  int GetLineOfCode() const { return line_of_code_; }
+  const std::string &GetErrDescription() const { return status_msg_; }
+  const std::string &SetErrDescription(const std::string &err_description);
+  friend std::ostream &operator<<(std::ostream &os, const Status &s);
   bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
   bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
   bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
   bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
-  operator bool() const = delete;
+  explicit operator bool() const { return (status_code_ == kSuccess); }
+  explicit operator int() const { return static_cast<int>(status_code_); }
+  static Status OK() { return Status(StatusCode::kSuccess); }
+  bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
+  bool IsError() const { return !IsOk(); }
+  static std::string CodeAsString(enum StatusCode c);
  private:
   enum StatusCode status_code_;
   std::string status_msg_;
+  int line_of_code_;
+  std::string file_name_;
+  std::string err_description_;
 };
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_STATUS_H
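Given the richer Status above, call sites can compare against either another Status or a bare StatusCode, and the explicit bool/int conversions replace the old IsSuccess(). A small sketch of the intended checking idioms; DoSomething stands in for any API call:

#include <iostream>
#include "include/api/status.h"

// Made-up helper returning a failure, just to exercise the checks below.
mindspore::Status DoSomething() { return mindspore::Status(mindspore::kMEInvalidInput, "bad input"); }

int CheckStatus() {
  mindspore::Status ret = DoSomething();
  if (ret != mindspore::kSuccess) {  // compare directly against a StatusCode
    std::cout << ret.ToString() << std::endl;
  }
  if (!ret.IsOk()) {                 // or use IsOk()/IsError()
    return static_cast<int>(ret);    // explicit int conversion exposes the raw code
  }
  return 0;
}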

View File

@@ -16,15 +16,20 @@
 #ifndef MINDSPORE_INCLUDE_API_TYPES_H
 #define MINDSPORE_INCLUDE_API_TYPES_H
+#include <cstddef>
 #include <string>
 #include <vector>
 #include <memory>
+#include "include/api/data_type.h"
+#ifdef _WIN32
+#define MS_API __declspec(dllexport)
+#else
 #define MS_API __attribute__((visibility("default")))
+#endif
 namespace mindspore {
-namespace api {
-enum ModelType {
+enum ModelType : uint32_t {
   kMindIR = 0,
   kAIR = 1,
   kOM = 2,
@@ -33,52 +38,38 @@ enum ModelType {
   kUnknownType = 0xFFFFFFFF
 };
-enum DataType {
-  kMsUnknown = 0,
-  kMsBool = 1,
-  kMsInt8 = 2,
-  kMsInt16 = 3,
-  kMsInt32 = 4,
-  kMsInt64 = 5,
-  kMsUint8 = 6,
-  kMsUint16 = 7,
-  kMsUint32 = 8,
-  kMsUint64 = 9,
-  kMsFloat16 = 10,
-  kMsFloat32 = 11,
-  kMsFloat64 = 12,
-  // insert new data type here
-  kInvalidDataType = 0xFFFFFFFF
-};
-class MS_API Tensor {
+class MS_API MSTensor {
  public:
-  Tensor();
-  Tensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, size_t data_len);
-  ~Tensor();
+  class Impl;
+  static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
+                               const void *data, size_t data_len) noexcept;
+  static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
+                                  const void *data, size_t data_len) noexcept;
+  MSTensor();
+  explicit MSTensor(const std::shared_ptr<Impl> &impl);
+  MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
+           size_t data_len);
+  ~MSTensor();
   const std::string &Name() const;
-  void SetName(const std::string &name);
-  api::DataType DataType() const;
-  void SetDataType(api::DataType type);
+  enum DataType DataType() const;
   const std::vector<int64_t> &Shape() const;
-  void SetShape(const std::vector<int64_t> &shape);
-  const void *Data() const;
+  int64_t ElementNum() const;
+  std::shared_ptr<const void> Data() const;
   void *MutableData();
   size_t DataSize() const;
-  bool ResizeData(size_t data_len);
-  bool SetData(const void *data, size_t data_len);
-  int64_t ElementNum() const;
-  static int GetTypeSize(api::DataType type);
-  Tensor Clone() const;
+  bool IsDevice() const;
+  MSTensor Clone() const;
+  bool operator==(std::nullptr_t) const;
  private:
-  class Impl;
+  friend class ModelImpl;
+  explicit MSTensor(std::nullptr_t);
   std::shared_ptr<Impl> impl_;
 };
@@ -101,21 +92,5 @@ class MS_API Buffer {
   class Impl;
   std::shared_ptr<Impl> impl_;
 };
-extern MS_API const char *kDeviceTypeAscend310;
-extern MS_API const char *kDeviceTypeAscend910;
-extern MS_API const char *kDeviceTypeGpu;
-constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
-constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
-constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";  // nchw or nhwc
-// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
-constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
-constexpr auto kModelOptionOutputType = "mindspore.option.output_type";  // "FP32", "UINT8" or "FP16", default as "FP32"
-constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
-// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
-constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
-// "high_precision" or "high_performance", default as "high_performance"
-}  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_TYPES_H
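A sketch of the two MSTensor factory functions declared above. The difference between them is an assumption drawn only from the names: CreateTensor presumably copies the host buffer, while CreateRefTensor wraps it by reference, so the caller keeps ownership of the memory.

#include <vector>
#include "include/api/types.h"

// Sketch: build a 1x3 float tensor from a host buffer. Copy-vs-reference semantics
// are inferred from the CreateTensor/CreateRefTensor names, not spelled out in the header.
void MakeTensors() {
  std::vector<float> host = {1.0f, 2.0f, 3.0f};
  auto owned = mindspore::MSTensor::CreateTensor("x", mindspore::DataType::kNumberTypeFloat32,
                                                 {1, 3}, host.data(), host.size() * sizeof(float));
  auto ref = mindspore::MSTensor::CreateRefTensor("x_ref", mindspore::DataType::kNumberTypeFloat32,
                                                  {1, 3}, host.data(), host.size() * sizeof(float));
  (void)owned.ElementNum();  // 3 elements
  (void)ref.DataSize();      // 12 bytes
}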

View File

@@ -23,7 +23,7 @@ if(ENABLE_D)
 endif()
 if(ENABLE_GPU)
-    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
+    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/gpu/*.cc")
 endif()
 set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
@@ -45,8 +45,13 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
     target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
                           -Wl,-force_load mindspore -Wl,-noall_load proto_input mindspore_gvar mindspore::protobuf)
 else()
-    target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
+    if(ENABLE_D OR ENABLE_ACL)
+        target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
                           -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
+    else()
+        target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
+                              mindspore proto_input mindspore_gvar mindspore::protobuf)
+    endif()
 endif()
 if(ENABLE_CPU)

View File

@@ -18,7 +18,7 @@
 #include "cxx_api/factory.h"
 #include "cxx_api/graph/graph_impl.h"
-namespace mindspore::api {
+namespace mindspore {
 std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
 ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {}
@@ -40,23 +40,23 @@ ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
   return *this;
 }
-ParameterCell::ParameterCell(const Tensor &tensor) : tensor_(tensor.Clone()) {}
+ParameterCell::ParameterCell(const MSTensor &tensor) : tensor_(tensor.Clone()) {}
-ParameterCell &ParameterCell::operator=(const Tensor &tensor) {
+ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
   tensor_ = tensor.Clone();
   return *this;
 }
-ParameterCell::ParameterCell(Tensor &&tensor) : tensor_(tensor) {}
+ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}
-ParameterCell &ParameterCell::operator=(Tensor &&tensor) {
+ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
   tensor_ = tensor;
   return *this;
 }
 GraphCell::GraphCell(const Graph &graph)
     : graph_(std::make_shared<Graph>(graph)),
-      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
   MS_EXCEPTION_IF_NULL(graph_);
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->SetGraph(graph_);
@@ -64,7 +64,7 @@ GraphCell::GraphCell(const Graph &graph)
 GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
     : graph_(graph),
-      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
   MS_EXCEPTION_IF_NULL(graph_);
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->SetGraph(graph_);
@@ -72,13 +72,13 @@ GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
 GraphCell::GraphCell(Graph &&graph)
     : graph_(std::make_shared<Graph>(graph)),
-      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+      executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
   MS_EXCEPTION_IF_NULL(graph_);
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->SetGraph(graph_);
 }
-Status GraphCell::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
   return executor_->Run(inputs, outputs);
 }
@@ -88,25 +88,24 @@ Status GraphCell::Load() {
   return executor_->Load();
 }
-Status GraphCell::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
+std::vector<MSTensor> GraphCell::GetInputs() {
   MS_EXCEPTION_IF_NULL(executor_);
-  return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes);
+  return executor_->GetInputs();
 }
-Status GraphCell::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                 std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
+std::vector<MSTensor> GraphCell::GetOutputs() {
   MS_EXCEPTION_IF_NULL(executor_);
-  return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
+  return executor_->GetOutputs();
 }
 InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
-InputAndOutput::InputAndOutput(const Tensor &tensor)
+InputAndOutput::InputAndOutput(const MSTensor &tensor)
     : cell_(std::make_shared<ParameterCell>(tensor.Clone())), prev_(), index_(-1) {}
-InputAndOutput::InputAndOutput(Tensor &&tensor) : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
+InputAndOutput::InputAndOutput(MSTensor &&tensor)
+    : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
 InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
                                int32_t index)
     : cell_(cell), prev_(prev), index_(index) {}
-}  // namespace mindspore::api
+}  // namespace mindspore

View File

@@ -16,49 +16,119 @@
 #include "include/api/context.h"
 #include "utils/log_adapter.h"
-namespace mindspore::api {
-class Context::ContextImpl {
- public:
-  ContextImpl() : device_target_("NotSet"), device_id_(0) {}
-  ~ContextImpl() = default;
-  const std::string &GetDeviceTarget() const { return device_target_; }
-  void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; }
-  uint32_t GetDeviceID() const { return device_id_; }
-  void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
- private:
-  std::string device_target_;
-  uint32_t device_id_;
-};
-Context &Context::Instance() {
-  static Context context;
-  return context;
-}
-const std::string &Context::GetDeviceTarget() const {
-  MS_EXCEPTION_IF_NULL(impl_);
-  return impl_->GetDeviceTarget();
-}
-Context &Context::SetDeviceTarget(const std::string &device_target) {
-  MS_EXCEPTION_IF_NULL(impl_);
-  impl_->SetDeviceTarget(device_target);
-  return *this;
-}
-uint32_t Context::GetDeviceID() const {
-  MS_EXCEPTION_IF_NULL(impl_);
-  return impl_->GetDeviceID();
-}
-Context &Context::SetDeviceID(uint32_t device_id) {
-  MS_EXCEPTION_IF_NULL(impl_);
-  impl_->SetDeviceID(device_id);
-  return *this;
-}
-Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); }
-Context::~Context() {}
-}  // namespace mindspore::api
+constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
+constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
+constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
+constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";  // nchw or nhwc
+constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
+// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
+constexpr auto kModelOptionOutputType = "mindspore.option.output_type";  // "FP32", "UINT8" or "FP16", default as "FP32"
+constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
+// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
+constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
+namespace mindspore {
+template <class T>
+static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
+  auto iter = context->params.find(key);
+  if (iter == context->params.end()) {
+    return T();
+  }
+  const std::any &value = iter->second;
+  if (value.type() != typeid(T)) {
+    return T();
+  }
+  return std::any_cast<T>(value);
+}
+std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
+  static std::shared_ptr<Context> g_context = std::make_shared<Context>();
+  return g_context;
+}
+void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  global_context->params[kGlobalContextDeviceTarget] = device_target;
+}
+std::string GlobalContext::GetGlobalDeviceTarget() {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
+}
+void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  global_context->params[kGlobalContextDeviceID] = device_id;
+}
+uint32_t GlobalContext::GetGlobalDeviceID() {
+  auto global_context = GetGlobalContext();
+  MS_EXCEPTION_IF_NULL(global_context);
+  return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
+}
+void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionInsertOpCfgPath] = cfg_path;
+}
+std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
+}
+void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionInputFormat] = format;
+}
+std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionInputFormat);
+}
+void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionInputShape] = shape;
+}
+std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionInputShape);
+}
+void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionOutputType] = output_type;
+}
+enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<enum DataType>(context, kModelOptionOutputType);
+}
+void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionPrecisionMode] = precision_mode;
+}
+std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionPrecisionMode);
+}
+void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
+                                       const std::string &op_select_impl_mode) {
+  MS_EXCEPTION_IF_NULL(context);
+  context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
+}
+std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
+}
+}  // namespace mindspore
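The GetValue helper above silently falls back to a default-constructed value when a key is missing or was stored with a different type. A standalone C++17 sketch of that behaviour, independent of the MindSpore types, to make the design choice concrete:

#include <any>
#include <iostream>
#include <map>
#include <string>

// Standalone illustration of the params-map pattern: missing key or wrong type -> T().
template <class T>
T GetValueDemo(const std::map<std::string, std::any> &params, const std::string &key) {
  auto iter = params.find(key);
  if (iter == params.end() || iter->second.type() != typeid(T)) {
    return T();
  }
  return std::any_cast<T>(iter->second);
}

int main() {
  std::map<std::string, std::any> params;
  params["device_id"] = uint32_t(3);
  std::cout << GetValueDemo<uint32_t>(params, "device_id") << "\n";     // 3
  std::cout << GetValueDemo<uint32_t>(params, "missing") << "\n";       // 0 (default)
  std::cout << GetValueDemo<std::string>(params, "device_id") << "\n";  // "" (type mismatch)
  return 0;
}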

View File

@@ -23,7 +23,7 @@
 #include <utility>
 #include "utils/utils.h"
-namespace mindspore::api {
+namespace mindspore {
 template <class T>
 class Factory {
   using U = std::function<std::shared_ptr<T>()>;
@@ -79,5 +79,5 @@ class Registrar {
 #define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS)                             \
   static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
     #DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_FACTORY_H
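The registration macro above pairs a device-target name with a creator lambda held by a static Registrar, which is how the AclGraphImpl backend later in this commit gets picked when the global device target is "Ascend310". A standalone sketch of that factory/registrar pattern, with demo class names rather than the real MindSpore ones:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Standalone illustration of the Factory/Registrar pattern: a static registrar object
// inserts a named creator at program start, and Create(name) builds the matching impl.
struct Base { virtual ~Base() = default; virtual void Hello() = 0; };

class DemoFactory {
 public:
  static DemoFactory &Instance() { static DemoFactory f; return f; }
  void Register(const std::string &name, std::function<std::shared_ptr<Base>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::shared_ptr<Base> Create(const std::string &name) {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }
 private:
  std::map<std::string, std::function<std::shared_ptr<Base>()>> creators_;
};

struct DemoRegistrar {
  DemoRegistrar(const std::string &name, std::function<std::shared_ptr<Base>()> creator) {
    DemoFactory::Instance().Register(name, std::move(creator));
  }
};

struct Ascend310Impl : Base { void Hello() override { std::cout << "Ascend310\n"; } };

// Plays the role of API_FACTORY_REG(Base, Ascend310, Ascend310Impl).
static const DemoRegistrar g_ascend310_reg("Ascend310", [] { return std::make_shared<Ascend310Impl>(); });

int main() {
  auto impl = DemoFactory::Instance().Create("Ascend310");  // e.g. the configured device target
  if (impl != nullptr) impl->Hello();
  return 0;
}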

View File

@@ -17,8 +17,8 @@
 #include "utils/log_adapter.h"
 #include "acl/acl.h"
-namespace mindspore::api {
-std::weak_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
+namespace mindspore {
+std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
 std::mutex AclEnvGuard::global_acl_env_mutex_;
 AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {
@@ -42,7 +42,7 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
   std::shared_ptr<AclEnvGuard> acl_env;
   std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
-  acl_env = global_acl_env_.lock();
+  acl_env = global_acl_env_;
   if (acl_env != nullptr) {
     MS_LOG(INFO) << "Acl has been initialized, skip.";
   } else {
@@ -57,4 +57,4 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
   }
   return acl_env;
 }
-}  // namespace mindspore::api
+}  // namespace mindspore

View File

@@ -20,7 +20,7 @@
 #include <mutex>
 #include "acl/acl_base.h"
-namespace mindspore::api {
+namespace mindspore {
 class __attribute__((visibility("default"))) AclEnvGuard {
  public:
   explicit AclEnvGuard(std::string_view cfg_file);
@@ -29,10 +29,10 @@ class __attribute__((visibility("default"))) AclEnvGuard {
   static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file);
  private:
-  static std::weak_ptr<AclEnvGuard> global_acl_env_;
+  static std::shared_ptr<AclEnvGuard> global_acl_env_;
   static std::mutex global_acl_env_mutex_;
   aclError errno_;
 };
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_ENV_GUARD_H

View File

@ -16,53 +16,50 @@
#include "cxx_api/graph/acl/acl_graph_impl.h" #include "cxx_api/graph/acl/acl_graph_impl.h"
#include "include/api/context.h" #include "include/api/context.h"
#include "cxx_api/model/acl/model_converter.h" #include "cxx_api/model/acl/model_converter.h"
#include "cxx_api/python_utils.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore::api { namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl); API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);
AclGraphImpl::AclGraphImpl() AclGraphImpl::AclGraphImpl()
: init_flag_(false), : init_flag_(false),
load_flag_(false), load_flag_(false),
device_type_("AscendCL"), device_type_("AscendCL"),
device_id_(Context::Instance().GetDeviceID()), device_id_(GlobalContext::GetGlobalDeviceID()),
context_(nullptr), context_(nullptr),
acl_env_(nullptr) {} acl_env_(nullptr) {}
AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); } AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); }
Status AclGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed."; MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED; return ret;
} }
return model_process_.PredictFromHost(inputs, outputs); return model_process_.PredictFromHost(inputs, outputs);
} }
Status AclGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AclGraphImpl::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed."; MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED; return {};
} }
return model_process_.GetInputsInfo(names, shapes, data_types, mem_sizes); return model_process_.GetInputs();
} }
Status AclGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AclGraphImpl::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed."; MS_LOG(ERROR) << "Prepare model resource failed.";
return FAILED; return {};
} }
return model_process_.GetOutputsInfo(names, shapes, data_types, mem_sizes); return model_process_.GetOutputs();
} }
Status AclGraphImpl::LoadAclModel(Buffer om_data) { Status AclGraphImpl::LoadAclModel(Buffer om_data) {
@ -72,44 +69,44 @@ Status AclGraphImpl::LoadAclModel(Buffer om_data) {
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id); auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
if (acl_ret != ACL_ERROR_NONE) { if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed.";
return FAILED; return kMCDeviceError;
} }
// acl init model resource // acl init model resource
model_process_.set_model_id(acl_model_id); model_process_.set_model_id(acl_model_id);
Status ret = model_process_.PreInitModelResource(); Status ret = model_process_.PreInitModelResource();
if (ret != SUCCESS) { if (ret != kSuccess) {
(void)aclmdlUnload(acl_model_id); (void)aclmdlUnload(acl_model_id);
MS_LOG(ERROR) << "Pre init model resource failed."; MS_LOG(ERROR) << "Pre init model resource failed.";
return FAILED; return ret;
} }
MS_LOG(INFO) << "Load acl model success."; MS_LOG(INFO) << "Load acl model success.";
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::InitEnv() { Status AclGraphImpl::InitEnv() {
if (init_flag_) { if (init_flag_) {
return SUCCESS; return kSuccess;
} }
acl_env_ = AclEnvGuard::GetAclEnv(""); acl_env_ = AclEnvGuard::GetAclEnv("");
if (acl_env_ == nullptr) { if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed."; MS_LOG(ERROR) << "Acl init failed.";
return FAILED; return kMCDeviceError;
} }
aclError ret = aclrtSetDevice(device_id_); aclError ret = aclrtSetDevice(device_id_);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed"; MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed";
return FAILED; return kMCDeviceError;
} }
MS_LOG(INFO) << "Open device " << device_id_ << " success"; MS_LOG(INFO) << "Open device " << device_id_ << " success";
ret = aclrtCreateContext(&context_, device_id_); ret = aclrtCreateContext(&context_, device_id_);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl create context failed"; MS_LOG(ERROR) << "Acl create context failed";
return FAILED; return kMCDeviceError;
} }
MS_LOG(INFO) << "Create context success"; MS_LOG(INFO) << "Create context success";
@ -117,7 +114,7 @@ Status AclGraphImpl::InitEnv() {
ret = aclrtGetRunMode(&run_mode); ret = aclrtGetRunMode(&run_mode);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl get run mode failed"; MS_LOG(ERROR) << "Acl get run mode failed";
return FAILED; return kMCDeviceError;
} }
bool is_device = (run_mode == ACL_DEVICE); bool is_device = (run_mode == ACL_DEVICE);
model_process_.SetIsDevice(is_device); model_process_.SetIsDevice(is_device);
@ -125,24 +122,24 @@ Status AclGraphImpl::InitEnv() {
MS_LOG(INFO) << "Init acl success, device id " << device_id_; MS_LOG(INFO) << "Init acl success, device id " << device_id_;
init_flag_ = true; init_flag_ = true;
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::FinalizeEnv() { Status AclGraphImpl::FinalizeEnv() {
if (!init_flag_) { if (!init_flag_) {
return SUCCESS; return kSuccess;
} }
aclError rt_ret = aclrtSetCurrentContext(context_); aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) { if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed"; MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED; return kMCDeviceError;
} }
Status ret = model_process_.UnLoad(); Status ret = model_process_.UnLoad();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Unload model inner failed."; MS_LOG(ERROR) << "Unload model inner failed.";
return FAILED; return ret;
} }
if (context_ != nullptr) { if (context_ != nullptr) {
@ -161,16 +158,16 @@ Status AclGraphImpl::FinalizeEnv() {
MS_LOG(INFO) << "End to reset device " << device_id_; MS_LOG(INFO) << "End to reset device " << device_id_;
init_flag_ = false; init_flag_ = false;
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::Load() { Status AclGraphImpl::Load() {
// check graph type // check graph type
if (graph_->ModelType() != ModelType::kOM) { if (graph_->ModelType() != ModelType::kOM) {
Status ret = ConvertToOM(); Status ret = ConvertToOM();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Load Failed."; MS_LOG(ERROR) << "Load Failed.";
return FAILED; return ret;
} }
} }
@ -180,15 +177,15 @@ Status AclGraphImpl::Load() {
// init // init
Status ret = InitEnv(); Status ret = InitEnv();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed."; MS_LOG(ERROR) << "InitEnv failed.";
return FAILED; return ret;
} }
// load model // load model
if (!load_flag_) { if (!load_flag_) {
ret = LoadAclModel(om_data); ret = LoadAclModel(om_data);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Load acl model failed."; MS_LOG(ERROR) << "Load acl model failed.";
return ret; return ret;
} }
@ -198,24 +195,24 @@ Status AclGraphImpl::Load() {
aclError rt_ret = aclrtSetCurrentContext(context_); aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) { if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed"; MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED; return kMCDeviceError;
} }
return SUCCESS; return kSuccess;
} }
Status AclGraphImpl::ConvertToOM() { Status AclGraphImpl::ConvertToOM() {
MS_LOG(INFO) << "Start convert to om model."; MS_LOG(INFO) << "Start convert to om model.";
if (graph_ == nullptr) { if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid graph_ is null."; MS_LOG(ERROR) << "Invalid graph_ is null.";
return FAILED; return kMCFailed;
} }
auto &graph_data = GraphImpl::MutableGraphData(); auto &graph_data = GraphImpl::MutableGraphData();
MS_EXCEPTION_IF_NULL(graph_data); MS_EXCEPTION_IF_NULL(graph_data);
if (graph_->ModelType() == ModelType::kOM) { if (graph_->ModelType() == ModelType::kOM) {
MS_LOG(INFO) << "This model has been built, skip."; MS_LOG(INFO) << "This model has been built, skip.";
return SUCCESS; return kSuccess;
} else if (graph_->ModelType() == ModelType::kMindIR) { } else if (graph_->ModelType() == ModelType::kMindIR) {
auto func_graph = graph_data->GetFuncGraph(); auto func_graph = graph_data->GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -223,13 +220,13 @@ Status AclGraphImpl::ConvertToOM() {
Buffer om_data = model_converter.LoadMindIR(func_graph); Buffer om_data = model_converter.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) { if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Convert MindIR to OM failed."; MS_LOG(ERROR) << "Convert MindIR to OM failed.";
return FAILED; return kMCFailed;
} }
graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM); graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM);
MS_LOG(INFO) << "Convert MindIR to OM success."; MS_LOG(INFO) << "Convert MindIR to OM success.";
return SUCCESS; return kSuccess;
} }
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
return FAILED; return kMCFailed;
} }
} // namespace mindspore::api } // namespace mindspore
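A minimal caller-side sketch, not part of this commit, of what the Status refactor above means in practice: instead of a bare FAILED, Load()/ConvertToOM() now surface a concrete code (kMCDeviceError, kMCFailed, ...) that can be forwarded as-is. LoadOrReport is a hypothetical name; only headers and members shown in this diff are assumed.

#include "include/api/status.h"
#include "cxx_api/graph/graph_impl.h"

// Hypothetical helper, not part of this commit: the concrete failure code from
// Load() is now visible to the caller instead of a generic FAILED.
mindspore::Status LoadOrReport(mindspore::GraphCell::GraphImpl *impl) {
  mindspore::Status ret = impl->Load();
  if (ret != mindspore::kSuccess) {
    // e.g. kMCDeviceError when the Ascend context cannot be set,
    // kMCFailed when MindIR-to-OM conversion fails.
    return ret;
  }
  return mindspore::kSuccess;
}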
@ -27,18 +27,16 @@
#include "cxx_api/graph/graph_impl.h" #include "cxx_api/graph/graph_impl.h"
#include "cxx_api/factory.h" #include "cxx_api/factory.h"
namespace mindspore::api { namespace mindspore {
class AclGraphImpl : public GraphCell::GraphImpl { class AclGraphImpl : public GraphCell::GraphImpl {
public: public:
AclGraphImpl(); AclGraphImpl();
~AclGraphImpl() override; ~AclGraphImpl() override;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override; Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GetInputs() override;
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; std::vector<MSTensor> GetOutputs() override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
private: private:
Status ConvertToOM(); Status ConvertToOM();
@ -56,5 +54,5 @@ class AclGraphImpl : public GraphCell::GraphImpl {
ModelProcess model_process_; ModelProcess model_process_;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
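With this header, per-tensor metadata no longer flows through GetInputsInfo/GetOutputsInfo out-parameters; GetInputs()/GetOutputs() return ready-made MSTensor objects. A hedged sketch, not part of this commit, of how a caller of the refactored interface might inspect them (DescribeInputs is a hypothetical helper; only MSTensor methods already used in this diff, Shape() and DataSize(), are assumed).

#include <cstddef>
#include <iostream>
#include <vector>
#include "include/api/types.h"
#include "cxx_api/graph/graph_impl.h"

// Hypothetical helper, not part of this commit: walk the model inputs exposed
// by the refactored interface instead of the old out-parameter lists.
void DescribeInputs(mindspore::GraphCell::GraphImpl *impl) {
  std::vector<mindspore::MSTensor> inputs = impl->GetInputs();
  for (size_t i = 0; i < inputs.size(); ++i) {
    std::cout << "input " << i << ": rank " << inputs[i].Shape().size()
              << ", " << inputs[i].DataSize() << " bytes" << std::endl;
  }
}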
@ -20,17 +20,19 @@
#include <map> #include <map>
#include "utils/utils.h" #include "utils/utils.h"
namespace mindspore::api { namespace mindspore {
static DataType TransToApiType(aclDataType data_type) { static DataType TransToApiType(aclDataType data_type) {
static const std::map<aclDataType, api::DataType> data_type_map = { static const std::map<aclDataType, enum DataType> data_type_map = {
{ACL_FLOAT16, api::kMsFloat16}, {ACL_FLOAT, api::kMsFloat32}, {ACL_DOUBLE, api::kMsFloat64}, {ACL_FLOAT16, DataType::kNumberTypeFloat16}, {ACL_FLOAT, DataType::kNumberTypeFloat32},
{ACL_INT8, api::kMsInt8}, {ACL_INT16, api::kMsInt16}, {ACL_INT32, api::kMsInt32}, {ACL_DOUBLE, DataType::kNumberTypeFloat64}, {ACL_INT8, DataType::kNumberTypeInt8},
{ACL_INT64, api::kMsInt64}, {ACL_UINT8, api::kMsUint8}, {ACL_UINT16, api::kMsUint16}, {ACL_INT16, DataType::kNumberTypeInt16}, {ACL_INT32, DataType::kNumberTypeInt32},
{ACL_UINT32, api::kMsUint32}, {ACL_UINT64, api::kMsUint64}, {ACL_BOOL, api::kMsBool}, {ACL_INT64, DataType::kNumberTypeInt64}, {ACL_UINT8, DataType::kNumberTypeUInt8},
{ACL_UINT16, DataType::kNumberTypeUInt16}, {ACL_UINT32, DataType::kNumberTypeUInt32},
{ACL_UINT64, DataType::kNumberTypeUInt64}, {ACL_BOOL, DataType::kNumberTypeBool},
}; };
auto it = data_type_map.find(data_type); auto it = data_type_map.find(data_type);
if (it == data_type_map.end()) { if (it == data_type_map.end()) {
return api::kInvalidDataType; return DataType::kTypeUnknown;
} else { } else {
return it->second; return it->second;
} }
@ -51,7 +53,7 @@ inline static void PushbackIfNotNull(U *vec, T &&item) {
} }
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names, static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
std::vector<std::vector<int64_t>> *shapes, std::vector<DataType> *data_types, std::vector<std::vector<int64_t>> *shapes, std::vector<enum DataType> *data_types,
std::vector<size_t> *mem_sizes) { std::vector<size_t> *mem_sizes) {
ClearIfNotNull(names); ClearIfNotNull(names);
ClearIfNotNull(shapes); ClearIfNotNull(shapes);
@ -66,41 +68,69 @@ static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_lis
} }
} }
static std::string ShapeToString(const std::vector<int64_t> &shape) {
std::string result = "[";
for (size_t i = 0; i < shape.size(); ++i) {
result += std::to_string(shape[i]);
if (i + 1 < shape.size()) {
result += ", ";
}
}
result += "]";
return result;
}
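For reference, a standalone check of the new ShapeToString helper above (illustrative only; the helper is duplicated here so the snippet compiles on its own).

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Copy of the helper above, for illustration.
static std::string ShapeToString(const std::vector<int64_t> &shape) {
  std::string result = "[";
  for (size_t i = 0; i < shape.size(); ++i) {
    result += std::to_string(shape[i]);
    if (i + 1 < shape.size()) {
      result += ", ";
    }
  }
  result += "]";
  return result;
}

int main() {
  std::cout << ShapeToString({1, 3, 224, 224}) << std::endl;  // prints [1, 3, 224, 224]
  return 0;
}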
Status ModelProcess::ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list,
std::vector<MSTensor> *tensor_list) {
MS_EXCEPTION_IF_NULL(tensor_list);
std::vector<std::string> names;
std::vector<std::vector<int64_t>> shapes;
std::vector<enum DataType> data_types;
std::vector<size_t> mem_sizes;
ConstructTensorDesc(acl_tensor_list, &names, &shapes, &data_types, &mem_sizes);
tensor_list->clear();
if (names.size() != acl_tensor_list.size() || shapes.size() != acl_tensor_list.size() ||
data_types.size() != acl_tensor_list.size() || mem_sizes.size() != acl_tensor_list.size()) {
MS_LOG(ERROR) << "Inner error, size do not match: names size " << names.size() << " shapes size " << shapes.size()
<< " data types size " << data_types.size() << " mem sizes size " << mem_sizes.size()
<< " acl_tensor_list size " << acl_tensor_list.size();
return kMCFailed;
}
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
tensor_list->emplace_back(names[i], data_types[i], shapes[i], nullptr, mem_sizes[i]);
auto ret = aclrtMemcpy((*tensor_list)[i].MutableData(), (*tensor_list)[i].DataSize(),
acl_tensor_list[i].device_data, acl_tensor_list[i].buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Memcpy input " << i << " from " << (is_run_on_device_ ? "host" : "device")
<< " to host failed, memory size " << acl_tensor_list[i].buffer_size;
return kMCFailed;
}
}
return kSuccess;
}
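A small hedged sketch, not part of this commit, of the construct-then-fill pattern ConstructTensors relies on: pass nullptr so the MSTensor owns a buffer of the requested size, then write through MutableData(). MakeHostTensor is a hypothetical helper; only the MSTensor constructor and methods already used above are assumed.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
#include "include/api/types.h"

// Hypothetical helper, not part of this commit: build a named host tensor and
// copy user data into it, mirroring the pattern in ConstructTensors above.
mindspore::MSTensor MakeHostTensor(const std::string &name, mindspore::DataType type,
                                   const std::vector<int64_t> &shape, const void *src, size_t size) {
  // As in ConstructTensors, passing nullptr lets the MSTensor own a buffer of `size` bytes.
  mindspore::MSTensor tensor(name, type, shape, nullptr, size);
  memcpy(tensor.MutableData(), src, size);
  return tensor;
}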
Status ModelProcess::PreInitModelResource() { Status ModelProcess::PreInitModelResource() {
model_desc_ = aclmdlCreateDesc(); model_desc_ = aclmdlCreateDesc();
aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_); aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_);
if (acl_ret != ACL_ERROR_NONE) { if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Read model desc failed"; MS_LOG(ERROR) << "Read model desc failed";
return FAILED; return kMCDeviceError;
} }
Status ret = InitInputsBuffer(); Status ret = InitInputsBuffer();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Create input buffer failed"; MS_LOG(ERROR) << "Create input buffer failed";
return FAILED; return ret;
} }
ret = InitOutputsBuffer(); ret = InitOutputsBuffer();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Create output buffer failed"; MS_LOG(ERROR) << "Create output buffer failed";
return FAILED; return ret;
} }
return SUCCESS; return kSuccess;
}
Status ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t *model_id) {
MS_EXCEPTION_IF_NULL(model_id);
aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), model_id);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name;
return FAILED;
}
MS_LOG(INFO) << "Load model success " << file_name;
model_id_ = *model_id;
if (PreInitModelResource() != SUCCESS) {
aclmdlUnload(model_id_);
MS_LOG(ERROR) << "Pre init model resource failed, file name is " << file_name;
return FAILED;
}
return SUCCESS;
} }
Status ModelProcess::InitInputsBuffer() { Status ModelProcess::InitInputsBuffer() {
@ -113,8 +143,8 @@ Status ModelProcess::InitInputsBuffer() {
if (!is_run_on_device_) { // need to copy input/output to/from device if (!is_run_on_device_) { // need to copy input/output to/from device
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device input buffer faild , input size " << buffer_size; MS_LOG(ERROR) << "Malloc device input buffer failed , input size " << buffer_size;
return FAILED; return kMCDeviceError;
} }
} }
@ -125,7 +155,7 @@ Status ModelProcess::InitInputsBuffer() {
if (!is_run_on_device_) { if (!is_run_on_device_) {
aclrtFree(data_mem_buffer); aclrtFree(data_mem_buffer);
} }
return FAILED; return kMCDeviceError;
} }
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i); aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount); std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
@ -137,7 +167,7 @@ Status ModelProcess::InitInputsBuffer() {
input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, input_name}); input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, input_name});
} }
MS_LOG(INFO) << "Create model inputs success"; MS_LOG(INFO) << "Create model inputs success";
return SUCCESS; return kSuccess;
} }
Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) { Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
@ -154,14 +184,14 @@ Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size
if (!is_run_on_device_) { if (!is_run_on_device_) {
ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size; MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
return FAILED; return kMCDeviceError;
} }
} else { } else {
ret = aclrtMallocHost(data_mem_buffer, buffer_size); ret = aclrtMallocHost(data_mem_buffer, buffer_size);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size; MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
return FAILED; return kMCDeviceError;
} }
} }
@ -169,16 +199,16 @@ Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size
if (data_buffer == nullptr) { if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed"; MS_LOG(ERROR) << "Create Data Buffer failed";
free_data_buffer(*data_mem_buffer); free_data_buffer(*data_mem_buffer);
return FAILED; return kMCDeviceError;
} }
ret = aclmdlAddDatasetBuffer(dataset, data_buffer); ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed"; MS_LOG(ERROR) << "add data buffer failed";
free_data_buffer(*data_mem_buffer); free_data_buffer(*data_mem_buffer);
aclDestroyDataBuffer(data_buffer); aclDestroyDataBuffer(data_buffer);
return FAILED; return kMCDeviceError;
} }
return SUCCESS; return kSuccess;
} }
Status ModelProcess::InitOutputsBuffer() { Status ModelProcess::InitOutputsBuffer() {
@ -186,7 +216,7 @@ Status ModelProcess::InitOutputsBuffer() {
outputs_ = aclmdlCreateDataset(); outputs_ = aclmdlCreateDataset();
if (outputs_ == nullptr) { if (outputs_ == nullptr) {
MS_LOG(ERROR) << "Create input dataset failed"; MS_LOG(ERROR) << "Create input dataset failed";
return FAILED; return kMCDeviceError;
} }
size_t output_size = aclmdlGetNumOutputs(model_desc_); size_t output_size = aclmdlGetNumOutputs(model_desc_);
MS_LOG(INFO) << "output_size = " << output_size; MS_LOG(INFO) << "output_size = " << output_size;
@ -194,9 +224,9 @@ Status ModelProcess::InitOutputsBuffer() {
auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i); auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
void *data_mem_buffer = nullptr; void *data_mem_buffer = nullptr;
if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != SUCCESS) { if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != kSuccess) {
MS_LOG(ERROR) << "add output data buffer failed, buffer size " << buffer_size; MS_LOG(ERROR) << "add output data buffer failed, buffer size " << buffer_size;
return FAILED; return kMCDeviceError;
} }
aclmdlIODims dims; aclmdlIODims dims;
ret = aclmdlGetOutputDims(model_desc_, i, &dims); ret = aclmdlGetOutputDims(model_desc_, i, &dims);
@ -207,7 +237,7 @@ Status ModelProcess::InitOutputsBuffer() {
} else { } else {
aclrtFreeHost(data_mem_buffer); aclrtFreeHost(data_mem_buffer);
} }
return FAILED; return kMCDeviceError;
} }
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i); aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount); std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
@ -219,7 +249,7 @@ Status ModelProcess::InitOutputsBuffer() {
output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, output_name}); output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, output_name});
} }
MS_LOG(INFO) << "Create model output success"; MS_LOG(INFO) << "Create model output success";
return SUCCESS; return kSuccess;
} }
void ModelProcess::DestroyInputsDataset() { void ModelProcess::DestroyInputsDataset() {
@ -273,50 +303,60 @@ Status ModelProcess::UnLoad() {
auto ret = aclmdlUnload(model_id_); auto ret = aclmdlUnload(model_id_);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed"; MS_LOG(ERROR) << "Unload model failed";
return FAILED; return kMCDeviceError;
} }
if (model_desc_ != nullptr) { if (model_desc_ != nullptr) {
ret = aclmdlDestroyDesc(model_desc_); ret = aclmdlDestroyDesc(model_desc_);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed"; MS_LOG(ERROR) << "Unload model failed";
return FAILED; return kMCDeviceError;
} }
model_desc_ = nullptr; model_desc_ = nullptr;
} }
DestroyInputsBuffer(); DestroyInputsBuffer();
DestroyOutputsBuffer(); DestroyOutputsBuffer();
MS_LOG(INFO) << "End unload model " << model_id_; MS_LOG(INFO) << "End unload model " << model_id_;
return SUCCESS; return kSuccess;
} }
Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) { Status ModelProcess::CheckAndInitInput(const std::vector<MSTensor> &inputs) {
aclError ret; aclError ret;
inputs_ = aclmdlCreateDataset(); inputs_ = aclmdlCreateDataset();
// check inputs // check inputs
if (inputs.size() != input_infos_.size()) { if (inputs.size() != input_infos_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given count " MS_LOG(ERROR) << "Inputs count not match, required count " << input_infos_.size() << ", given count "
<< inputs.size(); << inputs.size();
return INVALID_INPUTS; return kMCInvalidInput;
} }
for (size_t i = 0; i < input_infos_.size(); ++i) { for (size_t i = 0; i < input_infos_.size(); ++i) {
if (inputs[i].Shape() != input_infos_[i].dims) {
MS_LOG(INFO) << "Note: input " << i << " shape not match, required " << ShapeToString(input_infos_[i].dims)
<< ", given " << ShapeToString(inputs[i].Shape());
}
if (inputs[i].DataType() != TransToApiType(input_infos_[i].data_type)) {
MS_LOG(INFO) << "Note: input " << i << " data type not match, required "
<< TransToApiType(input_infos_[i].data_type) << ", given " << inputs[i].DataType();
}
if (inputs[i].DataSize() != input_infos_[i].buffer_size) { if (inputs[i].DataSize() != input_infos_[i].buffer_size) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << input_infos_[i].buffer_size
<< ", given count " << inputs[i].DataSize(); << ", given count " << inputs[i].DataSize();
return INVALID_INPUTS; return kMCInvalidInput;
} }
} }
// copy inputs // copy inputs
for (size_t i = 0; i < input_infos_.size(); ++i) { for (size_t i = 0; i < input_infos_.size(); ++i) {
const auto &info = input_infos_[i]; const auto &info = input_infos_[i];
const auto &input = inputs[i]; auto input = inputs[i];
const void *data = input.Data(); const void *data = input.MutableData();
void *input_buffer = nullptr; void *input_buffer = nullptr;
if (!is_run_on_device_) { if (!is_run_on_device_) {
ret = aclrtMemcpy(info.device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE); ret = aclrtMemcpy(info.device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize(); MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize();
return FAILED; return kMCDeviceError;
} }
input_buffer = info.device_data; input_buffer = info.device_data;
} else { } else {
@ -325,23 +365,23 @@ Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) {
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size); auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
if (data_buffer == nullptr) { if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed"; MS_LOG(ERROR) << "Create Data Buffer failed";
return FAILED; return kMCDeviceError;
} }
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer); ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
if (ret != ACL_ERROR_NONE) { if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed"; MS_LOG(ERROR) << "add data buffer failed";
aclDestroyDataBuffer(data_buffer); aclDestroyDataBuffer(data_buffer);
return FAILED; return kMCDeviceError;
} }
} }
return SUCCESS; return kSuccess;
} }
Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { Status ModelProcess::PredictFromHost(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
aclError acl_ret; aclError acl_ret;
Status ret = CheckAndInitInput(inputs); Status ret = CheckAndInitInput(inputs);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "check or init input failed"; MS_LOG(ERROR) << "check or init input failed";
DestroyInputsDataset(); DestroyInputsDataset();
return ret; // forward status error return ret; // forward status error
@ -361,50 +401,48 @@ Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vec
DestroyInputsDataset(); DestroyInputsDataset();
if (acl_ret != ACL_ERROR_NONE) { if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Execute Model Failed"; MS_LOG(ERROR) << "Execute Model Failed";
return FAILED; return kMCDeviceError;
} }
ret = BuildOutputs(outputs); ret = BuildOutputs(outputs);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Build outputs faield"; MS_LOG(ERROR) << "Build outputs failed";
return FAILED; return ret;
} }
MS_LOG(INFO) << "excute model success"; MS_LOG(INFO) << "Execute model success";
return SUCCESS; return kSuccess;
} }
Status ModelProcess::BuildOutputs(std::vector<Buffer> *outputs) { Status ModelProcess::BuildOutputs(std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
aclError ret;
// copy outputs // copy outputs
outputs->clear(); outputs->clear();
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST; auto inner_outputs = GetOutputs();
for (size_t i = 0; i < output_infos_.size(); ++i) { if (inner_outputs.size() != output_infos_.size()) {
const auto &info = output_infos_[i]; MS_LOG(ERROR) << "Invalid inner outputs size " << inner_outputs.size() << " do not match device output infos size "
outputs->emplace_back(Buffer()); << output_infos_.size();
auto output = outputs->rbegin(); return kMCFailed;
if (!output->ResizeData(info.buffer_size)) {
MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size;
return FAILED;
}
ret = aclrtMemcpy(output->MutableData(), output->DataSize(), info.device_data, info.buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device")
<< " to host failed, memory size " << info.buffer_size;
return FAILED;
}
} }
return SUCCESS; (*outputs) = inner_outputs;
return kSuccess;
} }
Status ModelProcess::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> ModelProcess::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { Status ret = ConstructTensors(input_infos_, &input_tensors_);
ConstructTensorDesc(input_infos_, names, shapes, data_types, mem_sizes); if (ret != kSuccess) {
return SUCCESS; MS_LOG(ERROR) << "ConstructTensors failed.";
input_tensors_.clear();
}
return input_tensors_;
} }
Status ModelProcess::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> ModelProcess::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { Status ret = ConstructTensors(output_infos_, &output_tensors_);
ConstructTensorDesc(output_infos_, names, shapes, data_types, mem_sizes); if (ret != kSuccess) {
return SUCCESS; MS_LOG(ERROR) << "ConstructTensors failed.";
output_tensors_.clear();
}
return output_tensors_;
} }
} // namespace mindspore::api } // namespace mindspore
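Taken together, the ModelProcess changes give a host-side round trip: GetInputs() returns writable host tensors sized from the model description, and PredictFromHost() copies them to the device, executes, and rebuilds the outputs as MSTensor objects. A hedged sketch of that flow, not part of this commit (RunOnce is a hypothetical helper; the include path is inferred from the header guard in the header below).

#include <cstring>
#include <vector>
#include "include/api/status.h"
#include "include/api/types.h"
#include "cxx_api/graph/acl/model_process.h"  // path assumed from the header guard below

// Hypothetical helper, not part of this commit: one host-side inference round trip.
mindspore::Status RunOnce(mindspore::ModelProcess *proc) {
  std::vector<mindspore::MSTensor> inputs = proc->GetInputs();  // host tensors sized from the model desc
  for (auto &in : inputs) {
    memset(in.MutableData(), 0, in.DataSize());  // placeholder; real callers copy their data here
  }
  std::vector<mindspore::MSTensor> outputs;
  mindspore::Status ret = proc->PredictFromHost(inputs, &outputs);
  if (ret != mindspore::kSuccess) {
    return ret;  // kMCInvalidInput / kMCDeviceError etc. are forwarded unchanged
  }
  return mindspore::kSuccess;
}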
@ -25,7 +25,7 @@
#include "include/api/status.h" #include "include/api/status.h"
#include "include/api/types.h" #include "include/api/types.h"
namespace mindspore::api { namespace mindspore {
struct AclTensorInfo { struct AclTensorInfo {
void *device_data; void *device_data;
size_t buffer_size; size_t buffer_size;
@ -45,14 +45,12 @@ class ModelProcess {
input_infos_(), input_infos_(),
output_infos_() {} output_infos_() {}
~ModelProcess() {} ~ModelProcess() {}
Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id);
Status UnLoad(); Status UnLoad();
Status PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); Status PredictFromHost(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status PreInitModelResource(); Status PreInitModelResource();
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GetInputs();
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const; std::vector<MSTensor> GetOutputs();
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
// override this method to avoid request/reply data copy // override this method to avoid request/reply data copy
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; } void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
@ -62,8 +60,9 @@ class ModelProcess {
private: private:
Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset); Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
Status CheckAndInitInput(const std::vector<Buffer> &inputs); Status CheckAndInitInput(const std::vector<MSTensor> &inputs);
Status BuildOutputs(std::vector<Buffer> *outputs); Status ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<MSTensor> *tensor_list);
Status BuildOutputs(std::vector<MSTensor> *outputs);
Status InitInputsBuffer(); Status InitInputsBuffer();
Status InitOutputsBuffer(); Status InitOutputsBuffer();
@ -80,7 +79,9 @@ class ModelProcess {
aclmdlDataset *outputs_; aclmdlDataset *outputs_;
std::vector<AclTensorInfo> input_infos_; std::vector<AclTensorInfo> input_infos_;
std::vector<AclTensorInfo> output_infos_; std::vector<AclTensorInfo> output_infos_;
std::vector<MSTensor> input_tensors_;
std::vector<MSTensor> output_tensors_;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H #endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H
@ -25,91 +25,51 @@
#include "backend/session/executor_manager.h" #include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
namespace mindspore::api { namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
AscendGraphImpl::AscendGraphImpl() AscendGraphImpl::AscendGraphImpl()
: session_impl_(nullptr), : session_impl_(nullptr),
graph_id_(0), graph_id_(0),
device_type_("Ascend"), device_type_("Ascend"),
device_id_(Context::Instance().GetDeviceID()), device_id_(GlobalContext::GetGlobalDeviceID()),
context_(nullptr), context_(nullptr),
inputs_(), inputs_info_(),
outputs_(), outputs_info_(),
input_names_(), input_names_(),
output_names_(), output_names_(),
init_flag_(false),
load_flag_(false) {} load_flag_(false) {}
AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); } AscendGraphImpl::~AscendGraphImpl() {}
Status AscendGraphImpl::InitEnv() { Status AscendGraphImpl::InitEnv() {
if (init_flag_) { MS_LOG(INFO) << "Start to init env.";
return SUCCESS; env_guard_ = MsEnvGuard::GetEnv(device_id_);
} if (env_guard_ == nullptr) {
RegAllOp(); MS_LOG(ERROR) << "Env init failed.";
auto ms_context = MsContext::GetInstance(); return kMCDeviceError;
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;
} }
session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice); session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
if (session_impl_ == nullptr) { if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
<< " is available."; << " is available.";
return FAILED; return kMCFailed;
} }
session_impl_->Init(device_id_); session_impl_->Init(device_id_);
init_flag_ = true; MS_LOG(INFO) << "InitEnv success.";
return SUCCESS; return kSuccess;
}
Status AscendGraphImpl::FinalizeEnv() {
if (!init_flag_) {
return SUCCESS;
}
MS_LOG_INFO << "Start finalize env";
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
{
PythonEnvGuard guard;
if (!context::CloseTsd(ms_context)) {
MS_LOG(ERROR) << "CloseTsd failed!";
return FAILED;
}
}
init_flag_ = false;
MS_LOG(INFO) << "End finalize env";
return SUCCESS;
} }
Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) { Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr); MS_ASSERT(session_impl_ != nullptr);
try { try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return SUCCESS; return kSuccess;
} catch (std::exception &e) { } catch (std::exception &e) {
MS_LOG(ERROR) << "CompileGraph failed: " << e.what(); MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
return FAILED; return kMCFailed;
} }
} }
@ -128,104 +88,104 @@ Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &i
MS_ASSERT(session_impl_ != nullptr); MS_ASSERT(session_impl_ != nullptr);
std::string error_msg; std::string error_msg;
if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) { if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
return Status(INVALID_INPUTS, error_msg); return Status(kMCInvalidInput, error_msg);
} }
return SUCCESS; return kSuccess;
} }
Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { Status AscendGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vector<MSTensor> *reply) {
MS_EXCEPTION_IF_NULL(reply); MS_EXCEPTION_IF_NULL(reply);
if (context_ == nullptr) { if (context_ == nullptr) {
MS_LOG(ERROR) << "rtCtx is nullptr"; MS_LOG(ERROR) << "rtCtx is nullptr";
return FAILED; return kMCDeviceError;
} }
rtError_t rt_ret = rtCtxSetCurrent(context_); rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Set Ascend rtCtx failed"; MS_LOG(ERROR) << "Set Ascend rtCtx failed";
return FAILED; return kMCDeviceError;
} }
vector<tensor::TensorPtr> inputs; vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) { for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i]; auto item = request[i];
auto input = inputs_[i]; auto input = inputs_info_[i];
if (input->Size() != item.DataSize()) { if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size " MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size(); << input->Size();
return FAILED; return kMCInvalidInput;
} }
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); auto ret = memcpy_s(input->data_c(), input->Size(), item.MutableData(), item.DataSize());
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Tensor copy failed"; MS_LOG(ERROR) << "MSTensor copy failed";
return FAILED; return kMCFailed;
} }
inputs.push_back(input); inputs.push_back(input);
} }
vector<tensor::TensorPtr> outputs = RunGraph(inputs); last_inputs_ = inputs;
std::vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) { if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed"; MS_LOG(ERROR) << "Execute Model Failed";
return FAILED; return kMCFailed;
} }
last_outputs_ = outputs;
reply->clear(); reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), *reply = GetOutputs();
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); return kSuccess;
return SUCCESS;
} }
Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AscendGraphImpl::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) { if (!load_flag_) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed."; MS_LOG(ERROR) << "PrepareModel failed.";
return ret; return {};
} }
} }
GraphUtils::ClearIfNotNull(names); std::vector<MSTensor> result(inputs_info_.size());
GraphUtils::ClearIfNotNull(shapes); for (size_t i = 0; i < inputs_info_.size(); ++i) {
GraphUtils::ClearIfNotNull(data_types); auto &tensor = inputs_info_[i];
GraphUtils::ClearIfNotNull(mem_sizes); void *data = nullptr;
for (size_t i = 0; i < inputs_.size(); i++) { size_t data_size = tensor->Size();
auto &tensor = inputs_[i]; if (i < last_inputs_.size()) {
GraphUtils::PushbackIfNotNull(names, input_names_[i]); data = last_inputs_[i]->data_c();
GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); data_size = last_inputs_[i]->Size();
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); }
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); result[i] =
MSTensor(input_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
} }
return SUCCESS; return result;
} }
Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) { if (!load_flag_) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed."; MS_LOG(ERROR) << "PrepareModel failed.";
return ret; return {};
} }
} }
GraphUtils::ClearIfNotNull(names); std::vector<MSTensor> result(outputs_info_.size());
GraphUtils::ClearIfNotNull(shapes); for (size_t i = 0; i < outputs_info_.size(); ++i) {
GraphUtils::ClearIfNotNull(data_types); auto &tensor = outputs_info_[i];
GraphUtils::ClearIfNotNull(mem_sizes); void *data = nullptr;
for (size_t i = 0; i < outputs_.size(); i++) { size_t data_size = tensor->Size();
auto &tensor = outputs_[i]; if (i < last_outputs_.size()) {
GraphUtils::PushbackIfNotNull(names, output_names_[i]); data = last_outputs_[i]->data_c();
GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); data_size = last_outputs_[i]->Size();
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); }
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); result[i] =
MSTensor(output_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
} }
return result;
return SUCCESS;
} }
Status AscendGraphImpl::Load() { Status AscendGraphImpl::Load() {
// check graph type // check graph type
if (graph_->ModelType() != ModelType::kMindIR) { if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
return INVALID_INPUTS; return kMCInvalidInput;
} }
const auto &graph_data = GraphImpl::MutableGraphData(); const auto &graph_data = GraphImpl::MutableGraphData();
@ -234,34 +194,34 @@ Status AscendGraphImpl::Load() {
// init // init
Status ret = InitEnv(); Status ret = InitEnv();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed."; MS_LOG(ERROR) << "InitEnv failed.";
return FAILED; return ret;
} }
// load model // load model
if (!load_flag_) { if (!load_flag_) {
ret = CompileGraph(func_graph); ret = CompileGraph(func_graph);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Compile graph model failed"; MS_LOG(ERROR) << "Compile graph model failed";
return FAILED; return ret;
} }
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) { if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed"; MS_LOG_ERROR << "Get model inputs info failed";
return FAILED; return kMCInvalidInput;
} }
if (outputs_.empty() || outputs_.size() != output_names_.size()) { if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed"; MS_LOG_ERROR << "Get model outputs info failed";
return FAILED; return kMCInvalidInput;
} }
// save d context // save d context
rtError_t rt_ret = rtCtxGetCurrent(&context_); rtError_t rt_ret = rtCtxGetCurrent(&context_);
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) { if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null"; MS_LOG(ERROR) << "the ascend device context is null";
return FAILED; return kMCDeviceError;
} }
MS_LOG(INFO) << "Load model success"; MS_LOG(INFO) << "Load model success";
@ -271,44 +231,112 @@ Status AscendGraphImpl::Load() {
rtError_t rt_ret = rtCtxSetCurrent(context_); rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed"; MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED; return kMCDeviceError;
} }
return SUCCESS; return kSuccess;
} }
Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) { if (!load_flag_) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed."; MS_LOG(ERROR) << "PrepareModel failed.";
return ret; return ret;
} }
} }
if (inputs.size() != inputs_.size()) { if (inputs.size() != inputs_info_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count "
return INVALID_INPUTS; << inputs.size();
return kMCInvalidInput;
} }
for (size_t i = 0; i < inputs_.size(); ++i) { for (size_t i = 0; i < inputs_info_.size(); ++i) {
if (inputs[i].DataSize() != inputs_[i]->Size()) { if (inputs[i].DataSize() != inputs_info_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size()
<< inputs[i].DataSize(); << ", given count " << inputs[i].DataSize();
return INVALID_INPUTS; return kMCInvalidInput;
} }
} }
if (ExecuteModel(inputs, outputs) != SUCCESS) {
Status ret = ExecuteModel(inputs, outputs);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Execute Model Failed"; MS_LOG(ERROR) << "Execute Model Failed";
return FAILED; return ret;
} }
if (outputs_.size() != outputs->size()) { if (outputs_info_.size() != outputs->size()) {
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info " MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
<< outputs_.size(); << outputs_info_.size();
return FAILED; return kMCFailed;
} }
return SUCCESS; return kSuccess;
} }
} // namespace mindspore::api
AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
MS_LOG(INFO) << "Start to init env.";
device_id_ = device_id;
RegAllOp();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
auto ret = rtSetDevice(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
}
MS_LOG(INFO) << "InitEnv success.";
errno_ = kSuccess;
}
AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
MS_LOG(INFO) << "Start finalize env";
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return;
}
auto ret = rtDeviceReset(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
}
errno_ = kSuccess;
MS_LOG(INFO) << "End finalize env";
}
std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) {
std::shared_ptr<MsEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_ms_env_mutex_);
acl_env = global_ms_env_.lock();
if (acl_env != nullptr) {
MS_LOG(INFO) << "Env has been initialized, skip.";
} else {
acl_env = std::make_shared<MsEnvGuard>(device_id);
if (acl_env->GetErrno() != kSuccess) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return nullptr;
}
global_ms_env_ = acl_env;
MS_LOG(INFO) << "Env init success";
}
return acl_env;
}
std::weak_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
} // namespace mindspore
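The MsEnvGuard introduced above replaces the old per-instance init_flag_/FinalizeEnv bookkeeping with a process-wide guard: a weak_ptr caches the live environment so concurrent graphs on a device share one rtSetDevice/rtDeviceReset pair, and the reset happens only when the last shared_ptr holder goes away. A standalone sketch of the pattern itself, illustrative only and not MindSpore code:

#include <cstdint>
#include <memory>
#include <mutex>

// The guard acquires the device in its constructor and releases it in its
// destructor; GetEnv() hands out a shared owner while caching it weakly, so
// concurrent users share one init/reset pair.
class EnvGuard {
 public:
  explicit EnvGuard(uint32_t device_id) : device_id_(device_id) { /* acquire device here */ }
  ~EnvGuard() { /* release device here */ }

  static std::shared_ptr<EnvGuard> GetEnv(uint32_t device_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto env = global_env_.lock();  // reuse the live guard if one exists
    if (env == nullptr) {
      env = std::make_shared<EnvGuard>(device_id);
      global_env_ = env;  // cached weakly: the cache never keeps the device alive by itself
    }
    return env;
  }

 private:
  static std::weak_ptr<EnvGuard> global_env_;
  static std::mutex mutex_;
  uint32_t device_id_;
};

std::weak_ptr<EnvGuard> EnvGuard::global_env_;
std::mutex EnvGuard::mutex_;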
@ -28,40 +28,56 @@
#include "ir/anf.h" #include "ir/anf.h"
#include "cxx_api/model/model_impl.h" #include "cxx_api/model/model_impl.h"
#include "runtime/context.h" #include "runtime/context.h"
#include "cxx_api/graph/graph_utils.h"
namespace mindspore::api { namespace mindspore {
class AscendGraphImpl : public GraphCell::GraphImpl { class AscendGraphImpl : public GraphCell::GraphImpl {
public: public:
AscendGraphImpl(); AscendGraphImpl();
~AscendGraphImpl() override; ~AscendGraphImpl() override;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override; Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GetInputs() override;
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; std::vector<MSTensor> GetOutputs() override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
private: private:
class MsEnvGuard;
Status InitEnv(); Status InitEnv();
Status FinalizeEnv();
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr); Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const; Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs); std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); Status ExecuteModel(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
std::shared_ptr<session::SessionBasic> session_impl_; std::shared_ptr<session::SessionBasic> session_impl_;
uint32_t graph_id_; uint32_t graph_id_;
std::string device_type_; std::string device_type_;
uint32_t device_id_; uint32_t device_id_;
rtContext_t context_; rtContext_t context_;
std::vector<tensor::TensorPtr> inputs_; std::vector<tensor::TensorPtr> inputs_info_;
std::vector<tensor::TensorPtr> outputs_; std::vector<tensor::TensorPtr> outputs_info_;
std::vector<tensor::TensorPtr> last_inputs_;
std::vector<tensor::TensorPtr> last_outputs_;
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
bool init_flag_;
bool load_flag_; bool load_flag_;
std::shared_ptr<MsEnvGuard> env_guard_;
}; };
} // namespace mindspore::api
class AscendGraphImpl::MsEnvGuard {
public:
explicit MsEnvGuard(uint32_t device_id);
~MsEnvGuard();
Status GetErrno() const { return errno_; }
static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id);
private:
static std::weak_ptr<MsEnvGuard> global_ms_env_;
static std::mutex global_ms_env_mutex_;
Status errno_;
uint32_t device_id_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
@ -23,15 +23,15 @@
#include "backend/session/executor_manager.h" #include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
namespace mindspore::api { namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl); API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
GPUGraphImpl::GPUGraphImpl() GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr), : session_impl_(nullptr),
graph_id_(0), graph_id_(0),
device_id_(Context::Instance().GetDeviceID()), device_id_(GlobalContext::GetGlobalDeviceID()),
inputs_(), inputs_info_(),
outputs_(), outputs_info_(),
input_names_(), input_names_(),
output_names_(), output_names_(),
init_flag_(false), init_flag_(false),
@ -40,13 +40,13 @@ GPUGraphImpl::GPUGraphImpl()
Status GPUGraphImpl::InitEnv() { Status GPUGraphImpl::InitEnv() {
if (init_flag_) { if (init_flag_) {
MS_LOG(WARNING) << "Initialized again, return success."; MS_LOG(WARNING) << "Initialized again, return success.";
return SUCCESS; return kSuccess;
} }
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) { if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!"; MS_LOG(ERROR) << "Get Context failed!";
return FAILED; return kMCFailed;
} }
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
@ -57,18 +57,18 @@ Status GPUGraphImpl::InitEnv() {
if (session_impl_ == nullptr) { if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kGpuInferenceDevice MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kGpuInferenceDevice
<< " is available."; << " is available.";
return FAILED; return kMCFailed;
} }
session_impl_->Init(device_id_); session_impl_->Init(device_id_);
init_flag_ = true; init_flag_ = true;
return SUCCESS; return kSuccess;
} }
Status GPUGraphImpl::FinalizeEnv() { Status GPUGraphImpl::FinalizeEnv() {
if (!init_flag_) { if (!init_flag_) {
MS_LOG(WARNING) << "Never initialize before, return success"; MS_LOG(WARNING) << "Never initialize before, return success";
return SUCCESS; return kSuccess;
} }
MS_LOG_INFO << "Start finalize env"; MS_LOG_INFO << "Start finalize env";
@ -77,14 +77,14 @@ Status GPUGraphImpl::FinalizeEnv() {
init_flag_ = false; init_flag_ = false;
MS_LOG(INFO) << "End finalize env"; MS_LOG(INFO) << "End finalize env";
return SUCCESS; return kSuccess;
} }
Status GPUGraphImpl::Load() { Status GPUGraphImpl::Load() {
// check graph type // check graph type
if (graph_->ModelType() != ModelType::kMindIR) { if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
return INVALID_INPUTS; return kMCInvalidInput;
} }
const auto &graph_data = GraphImpl::MutableGraphData(); const auto &graph_data = GraphImpl::MutableGraphData();
@ -93,38 +93,38 @@ Status GPUGraphImpl::Load() {
// init // init
Status ret = InitEnv(); Status ret = InitEnv();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed."; MS_LOG(ERROR) << "InitEnv failed.";
return FAILED; return kMCDeviceError;
} }
ret = CompileGraph(func_graph); ret = CompileGraph(func_graph);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Compile graph model failed"; MS_LOG(ERROR) << "Compile graph model failed";
return FAILED; return kMCFailed;
} }
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) { if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed"; MS_LOG_ERROR << "Get model inputs info failed";
return FAILED; return kMCInvalidInput;
} }
if (outputs_.empty() || outputs_.size() != output_names_.size()) { if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed"; MS_LOG_ERROR << "Get model outputs info failed";
return FAILED; return kMCInvalidInput;
} }
load_flag_ = true; load_flag_ = true;
return SUCCESS; return kSuccess;
} }
Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) { Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr); MS_ASSERT(session_impl_ != nullptr);
try { try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return SUCCESS; return kSuccess;
} catch (std::exception &e) { } catch (std::exception &e) {
MS_LOG(ERROR) << "CompileGraph failed: " << e.what(); MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
return FAILED; return kMCFailed;
} }
} }
@ -139,118 +139,118 @@ std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::
} }
} }
Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { Status GPUGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vector<MSTensor> *reply) {
MS_EXCEPTION_IF_NULL(reply); MS_EXCEPTION_IF_NULL(reply);
vector<tensor::TensorPtr> inputs; vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) { for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i]; auto &item = request[i];
auto input = inputs_[i]; auto input = inputs_info_[i];
if (input->Size() != item.DataSize()) { if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size " MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size(); << input->Size();
return FAILED; return kMCInvalidInput;
} }
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); auto ret = memcpy_s(input->data_c(), input->Size(), item.Data().get(), item.DataSize());
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Tensor copy failed"; MS_LOG(ERROR) << "Tensor copy failed";
return FAILED; return kMCFailed;
} }
inputs.push_back(input); inputs.push_back(input);
} }
vector<tensor::TensorPtr> outputs = RunGraph(inputs); last_inputs_ = inputs;
std::vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) { if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed"; MS_LOG(ERROR) << "Execute Model Failed";
return FAILED; return kMCFailed;
} }
last_outputs_ = outputs;
reply->clear(); reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), *reply = GetOutputs();
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); return kSuccess;
return SUCCESS;
} }
Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) { if (!load_flag_) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed."; MS_LOG(ERROR) << "PrepareModel failed.";
return ret; return ret;
} }
} }
if (inputs.size() != inputs_.size()) { if (inputs.size() != inputs_info_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count "
return INVALID_INPUTS; << inputs.size();
return kMCInvalidInput;
} }
for (size_t i = 0; i < inputs_.size(); ++i) { for (size_t i = 0; i < inputs_info_.size(); ++i) {
if (inputs[i].DataSize() != inputs_[i]->Size()) { if (inputs[i].DataSize() != inputs_info_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size()
<< inputs[i].DataSize(); << ", given count " << inputs[i].DataSize();
return INVALID_INPUTS; return kMCInvalidInput;
} }
} }
if (ExecuteModel(inputs, outputs) != SUCCESS) { if (ExecuteModel(inputs, outputs) != kSuccess) {
MS_LOG(ERROR) << "Execute Model Failed"; MS_LOG(ERROR) << "Execute Model Failed";
return FAILED; return kMCFailed;
} }
if (outputs_.size() != outputs->size()) { if (outputs_info_.size() != outputs->size()) {
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info " MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
<< outputs_.size(); << outputs_info_.size();
return FAILED; return kMCFailed;
} }
return SUCCESS; return kSuccess;
} }
Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GPUGraphImpl::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) { if (!load_flag_) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed."; MS_LOG(ERROR) << "PrepareModel failed.";
return ret; return {};
} }
} }
GraphUtils::ClearIfNotNull(names); std::vector<MSTensor> result(inputs_info_.size());
GraphUtils::ClearIfNotNull(shapes); for (size_t i = 0; i < inputs_info_.size(); ++i) {
GraphUtils::ClearIfNotNull(data_types); auto &tensor = inputs_info_[i];
GraphUtils::ClearIfNotNull(mem_sizes); void *data = nullptr;
for (size_t i = 0; i < inputs_.size(); i++) { size_t data_size = tensor->Size();
auto &tensor = inputs_[i]; if (i < last_inputs_.size()) {
GraphUtils::PushbackIfNotNull(names, input_names_[i]); data = last_inputs_[i]->data_c();
GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); data_size = last_inputs_[i]->Size();
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); }
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); result[i] =
MSTensor(input_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
} }
return SUCCESS; return result;
} }
Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
if (!load_flag_) { if (!load_flag_) {
Status ret = Load(); Status ret = Load();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed."; MS_LOG(ERROR) << "PrepareModel failed.";
return ret; return {};
} }
} }
GraphUtils::ClearIfNotNull(names); std::vector<MSTensor> result(outputs_info_.size());
GraphUtils::ClearIfNotNull(shapes); for (size_t i = 0; i < outputs_info_.size(); ++i) {
GraphUtils::ClearIfNotNull(data_types); auto &tensor = outputs_info_[i];
GraphUtils::ClearIfNotNull(mem_sizes); void *data = nullptr;
for (size_t i = 0; i < outputs_.size(); i++) { size_t data_size = tensor->Size();
auto &tensor = outputs_[i]; if (i < last_outputs_.size()) {
GraphUtils::PushbackIfNotNull(names, output_names_[i]); data = last_outputs_[i]->data_c();
GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); data_size = last_outputs_[i]->Size();
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); }
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); result[i] =
MSTensor(output_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
} }
return result;
return SUCCESS;
} }
} // namespace mindspore::api } // namespace mindspore
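As on the Ascend path, GetInputs()/GetOutputs() above wrap the tensors of the most recent run (last_inputs_/last_outputs_); before a successful Run() the returned MSTensor objects carry name, type, shape and size but a null data pointer. A short hedged usage sketch, not part of this commit (RunAndPrint is a hypothetical helper; only interfaces shown in this diff are assumed).

#include <cstddef>
#include <iostream>
#include <vector>
#include "include/api/status.h"
#include "include/api/types.h"
#include "cxx_api/graph/graph_impl.h"

// Hypothetical helper, not part of this commit: run once, then read this run's outputs.
mindspore::Status RunAndPrint(mindspore::GraphCell::GraphImpl *impl,
                              const std::vector<mindspore::MSTensor> &inputs) {
  std::vector<mindspore::MSTensor> outputs;
  mindspore::Status ret = impl->Run(inputs, &outputs);
  if (ret != mindspore::kSuccess) {
    return ret;
  }
  for (size_t i = 0; i < outputs.size(); ++i) {
    std::cout << "output " << i << ": " << outputs[i].DataSize() << " bytes" << std::endl;
  }
  return mindspore::kSuccess;
}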
@ -25,20 +25,17 @@
#include "backend/session/session_basic.h" #include "backend/session/session_basic.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "cxx_api/model/model_impl.h" #include "cxx_api/model/model_impl.h"
#include "cxx_api/graph/graph_utils.h"
namespace mindspore::api { namespace mindspore {
class GPUGraphImpl : public GraphCell::GraphImpl { class GPUGraphImpl : public GraphCell::GraphImpl {
public: public:
GPUGraphImpl(); GPUGraphImpl();
~GPUGraphImpl() override = default; ~GPUGraphImpl() override = default;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override; Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GetInputs() override;
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; std::vector<MSTensor> GetOutputs() override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
private: private:
Status InitEnv(); Status InitEnv();
@ -46,14 +43,16 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr); Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const; Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs); std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); Status ExecuteModel(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
std::shared_ptr<session::SessionBasic> session_impl_; std::shared_ptr<session::SessionBasic> session_impl_;
uint32_t graph_id_; uint32_t graph_id_;
std::string device_type_; std::string device_type_;
uint32_t device_id_; uint32_t device_id_;
std::vector<tensor::TensorPtr> inputs_; std::vector<tensor::TensorPtr> inputs_info_;
std::vector<tensor::TensorPtr> outputs_; std::vector<tensor::TensorPtr> outputs_info_;
std::vector<tensor::TensorPtr> last_inputs_;
std::vector<tensor::TensorPtr> last_outputs_;
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
bool init_flag_; bool init_flag_;
@ -63,5 +62,5 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
uint32_t batch_size_; uint32_t batch_size_;
uint32_t workspace_size_; uint32_t workspace_size_;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
@ -17,15 +17,19 @@
#include "cxx_api/graph/graph_data.h" #include "cxx_api/graph/graph_data.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore::api { namespace mindspore {
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {} Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {} Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
Graph::~Graph() {} Graph::~Graph() {}
Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
ModelType Graph::ModelType() const { ModelType Graph::ModelType() const {
MS_EXCEPTION_IF_NULL(graph_data_); MS_EXCEPTION_IF_NULL(graph_data_);
return graph_data_->ModelType(); return graph_data_->ModelType();
} }
} // namespace mindspore::api } // namespace mindspore
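The added std::nullptr_t constructor and comparison let callers treat Graph as an optional handle. A minimal caller-side sketch, assuming Graph stays copy-assignable (illustration only, not part of this commit; the file name is hypothetical):

  Graph graph(nullptr);  // declare an empty handle first
  graph = Serialization::LoadModel("net.mindir", ModelType::kMindIR);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Graph handle is still empty.";
  }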
@ -19,7 +19,7 @@
#include "framework/common/helper/model_helper.h" #include "framework/common/helper/model_helper.h"
#endif #endif
namespace mindspore::api { namespace mindspore {
Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type) Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) { : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
if (model_type != ModelType::kMindIR) { if (model_type != ModelType::kMindIR) {
@ -72,4 +72,4 @@ Buffer Graph::GraphData::GetOMData() const {
return om_data_; return om_data_;
} }
} // namespace mindspore::api } // namespace mindspore
@ -24,7 +24,7 @@
#include "include/api/types.h" #include "include/api/types.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
namespace mindspore::api { namespace mindspore {
class Graph::GraphData { class Graph::GraphData {
public: public:
GraphData(); GraphData();
@ -46,5 +46,5 @@ class Graph::GraphData {
Buffer om_data_; Buffer om_data_;
enum ModelType model_type_; enum ModelType model_type_;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
@ -26,7 +26,7 @@
#include "cxx_api/graph/graph_data.h" #include "cxx_api/graph/graph_data.h"
#include "utils/utils.h" #include "utils/utils.h"
namespace mindspore::api { namespace mindspore {
class GraphCell::GraphImpl { class GraphCell::GraphImpl {
public: public:
GraphImpl() = default; GraphImpl() = default;
@ -35,17 +35,14 @@ class GraphCell::GraphImpl {
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; } std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0; virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
virtual Status Load() = 0; virtual Status Load() = 0;
virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, virtual std::vector<MSTensor> GetInputs() = 0;
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0; virtual std::vector<MSTensor> GetOutputs() = 0;
virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
protected: protected:
std::shared_ptr<Graph> graph_; std::shared_ptr<Graph> graph_;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
@ -1,63 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#include <map>
#include <vector>
#include "include/api/types.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
class GraphUtils {
public:
static DataType TransTypeId2InferDataType(TypeId type_id) {
const std::map<TypeId, api::DataType> id2type_map{
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
};
auto it = id2type_map.find(type_id);
if (it != id2type_map.end()) {
return it->second;
}
MS_LOG(WARNING) << "Unsupported data id " << type_id;
return api::kMsUnknown;
}
template <class T>
inline static void ClearIfNotNull(T *vec) {
if (vec != nullptr) {
vec->clear();
}
}
template <class T, class U>
inline static void PushbackIfNotNull(U *vec, T &&item) {
if (vec != nullptr) {
vec->emplace_back(item);
}
}
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
@ -16,47 +16,53 @@
#include "cxx_api/model/acl/acl_model.h" #include "cxx_api/model/acl/acl_model.h"
#include <memory> #include <memory>
#include "include/api/context.h"
#include "cxx_api/factory.h" #include "cxx_api/factory.h"
#include "cxx_api/python_utils.h"
namespace mindspore::api { namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModel); API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
Status AclModel::Build(const std::map<std::string, std::string> &options_map) { Status AclModel::Build() {
MS_LOG(INFO) << "Start build model."; MS_LOG(INFO) << "Start build model.";
MS_EXCEPTION_IF_NULL(graph_); MS_EXCEPTION_IF_NULL(graph_);
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(options_map);
std::string options_str = GenerateOptionsStr(options_map); if (graph_cell_ != nullptr) {
MS_EXCEPTION_IF_NULL(options);
if (graph_cell_ != nullptr && options_str == options_str_) {
MS_LOG(INFO) << "This model has been built, skip."; MS_LOG(INFO) << "This model has been built, skip.";
return SUCCESS; return kSuccess;
} }
if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) { if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) {
MS_LOG(INFO) << "Note: Load om model and all build options will be ignored.";
graph_cell_ = std::make_shared<GraphCell>(graph_); graph_cell_ = std::make_shared<GraphCell>(graph_);
MS_EXCEPTION_IF_NULL(graph_cell_); MS_EXCEPTION_IF_NULL(graph_cell_);
if (!options_map.empty()) { return kSuccess;
MS_LOG(WARNING) << "All build options will be ignored."; }
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_);
MS_EXCEPTION_IF_NULL(options);
std::string options_key = options->GenAclOptionsKey();
std::shared_ptr<Graph> graph;
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
MS_LOG(INFO) << "This options has been built, read cache.";
graph = iter->second;
} else {
auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
model_converter_.set_options(options.get());
auto om_data = model_converter_.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed.";
return kMCFailed;
} }
return SUCCESS; graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
dynamic_size_graph_map_[options_key] = graph;
} }
auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
model_converter_.set_options(options.get());
auto om_data = model_converter_.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed.";
return FAILED;
}
auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph); auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell); MS_EXCEPTION_IF_NULL(graph_cell);
auto ret = ModelImpl::Load(graph_cell); auto ret = ModelImpl::Load(graph_cell);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed."; MS_LOG(ERROR) << "Load failed.";
return ret; return ret;
} }
@ -64,64 +70,97 @@ Status AclModel::Build(const std::map<std::string, std::string> &options_map) {
// save result // save result
graph_cell_ = graph_cell; graph_cell_ = graph_cell;
options_ = std::move(options); options_ = std::move(options);
options_str_ = options_str;
MS_LOG(INFO) << "Build model success."; MS_LOG(INFO) << "Build model success.";
return SUCCESS; return kSuccess;
} }
Status AclModel::Train(const DataSet &, std::map<std::string, Buffer> *) { Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(INFO) << "Start to resize model.";
return FAILED; MS_EXCEPTION_IF_NULL(graph_);
if (graph_->ModelType() == ModelType::kOM) {
MS_LOG(ERROR) << "OM model is not supported to resize model.";
return kMCFailed;
}
auto origin_inputs = GetInputs();
if (inputs.size() != origin_inputs.size()) {
MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size();
return kMCInvalidInput;
}
if (inputs.size() != dims.size()) {
MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size();
return kMCInvalidInput;
}
if (model_context_ == nullptr) {
model_context_ = std::make_shared<ModelContext>();
}
std::string input_shape_option;
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].Name() != origin_inputs[i].Name()) {
MS_LOG(ERROR) << "Invalid inputs " << i << " name " << inputs[i].Name() << " not match model input name "
<< origin_inputs[i].Name();
return kMCInvalidInput;
}
input_shape_option += inputs[i].Name() + ":";
for (size_t j = 0; j < dims[i].size(); ++j) {
input_shape_option += std::to_string(dims[i][j]);
if (j + 1 < dims[i].size()) {
input_shape_option += ",";
}
}
if (i + 1 < inputs.size()) {
input_shape_option += ";";
}
}
MS_LOG(INFO) << "Set input size option is " << input_shape_option;
ModelContext::SetInputShape(model_context_, input_shape_option);
auto graph_cell_bak = std::move(graph_cell_);
auto ret = Build();
if (ret != kSuccess) {
MS_LOG(INFO) << "Resize build failed.";
graph_cell_ = std::move(graph_cell_bak);
return ret;
}
MS_LOG(INFO) << "Resize success.";
return kSuccess;
} }
Status AclModel::Eval(const DataSet &, std::map<std::string, Buffer> *) { Status AclModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
}
Status AclModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
if (graph_ == nullptr) { if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid data, graph_ is null."; MS_LOG(ERROR) << "Invalid data, graph_ is null.";
return FAILED; return kMCFailed;
} }
if (graph_cell_ == nullptr) { if (graph_cell_ == nullptr) {
MS_LOG(WARNING) << "Model has not been built, it will be built with default options"; MS_LOG(WARNING) << "Model has not been built, it will be built with default options";
Status ret = Build({}); Status ret = Build();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed."; MS_LOG(ERROR) << "Build model failed.";
return FAILED; return ret;
} }
} }
MS_EXCEPTION_IF_NULL(graph_cell_); MS_EXCEPTION_IF_NULL(graph_cell_);
Status ret = graph_cell_->Run(inputs, outputs); Status ret = graph_cell_->Run(inputs, outputs);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Run graph failed."; MS_LOG(ERROR) << "Run graph failed.";
return FAILED; return ret;
} }
return SUCCESS; return kSuccess;
} }
Status AclModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AclModel::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(graph_cell_); MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes); return graph_cell_->GetInputs();
} }
Status AclModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> AclModel::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(graph_cell_); MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes); return graph_cell_->GetOutputs();
} }
} // namespace mindspore
std::string AclModel::GenerateOptionsStr(const std::map<std::string, std::string> &options) {
std::string ret;
for (auto &[key, value] : options) {
ret += key + "^" + value + "^^";
}
return ret;
}
} // namespace mindspore::api
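For reference, a hedged sketch of how the new Resize path is driven from the public Model API (illustration only; the model object and the target shape are assumptions, not part of this diff). On failure the previous graph cell is restored, so the model stays usable:

  std::vector<MSTensor> inputs = model.GetInputs();
  std::vector<std::vector<int64_t>> new_dims = {{1, 3, 224, 224}};  // hypothetical static shape, one entry per input
  if (model.Resize(inputs, new_dims) != kSuccess) {
    MS_LOG(ERROR) << "Resize failed, keep running with the old shapes.";
  }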
@ -31,30 +31,25 @@
#include "ir/tensor.h" #include "ir/tensor.h"
#include "ir/anf.h" #include "ir/anf.h"
namespace mindspore::api { namespace mindspore {
class AclModel : public ModelImpl { class AclModel : public ModelImpl {
public: public:
AclModel() : model_converter_(), options_(nullptr), options_str_() {} AclModel() : model_converter_(), options_(nullptr) {}
~AclModel() = default; ~AclModel() = default;
Status Build(const std::map<std::string, std::string> &options_map) override; Status Build() override;
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GetInputs() override;
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override; std::vector<MSTensor> GetOutputs() override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
private: private:
static std::string GenerateOptionsStr(const std::map<std::string, std::string> &options);
std::shared_ptr<GraphCell> graph_cell_; std::shared_ptr<GraphCell> graph_cell_;
ModelConverter model_converter_; ModelConverter model_converter_;
std::unique_ptr<AclModelOptions> options_; std::unique_ptr<AclModelOptions> options_;
std::string options_str_; std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H #endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H
@ -18,23 +18,31 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "external/ge/ge_api_types.h" #include "external/ge/ge_api_types.h"
namespace mindspore::api { namespace mindspore {
static std::string ParseOption(const std::map<std::string, std::string> &options, const std::string &key) { static const std::map<enum DataType, std::string> kSupportedDtypeOptionMap = {{DataType::kNumberTypeFloat16, "FP16"},
auto iter = options.find(key); {DataType::kNumberTypeFloat32, "FP32"},
if (iter != options.end()) { {DataType::kNumberTypeUInt8, "UINT8"}};
return iter->second;
}
return "";
}
AclModelOptions::AclModelOptions(const std::map<std::string, std::string> &options) { AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
// to acl if (context == nullptr) {
insert_op_cfg_path = ParseOption(options, kModelOptionInsertOpCfgPath); return;
input_format = ParseOption(options, kModelOptionInputFormat); }
input_shape = ParseOption(options, kModelOptionInputShape); insert_op_cfg_path = ModelContext::GetInsertOpConfigPath(context);
output_type = ParseOption(options, kModelOptionOutputType); input_format = ModelContext::GetInputFormat(context);
precision_mode = ParseOption(options, kModelOptionPrecisionMode); input_shape = ModelContext::GetInputShape(context);
op_select_impl_mode = ParseOption(options, kModelOptionOpSelectImplMode);
auto out_type = ModelContext::GetOutputType(context);
auto iter = kSupportedDtypeOptionMap.find(out_type);
if (out_type == DataType::kTypeUnknown) {
// do nothing
} else if (iter == kSupportedDtypeOptionMap.end()) {
MS_LOG(WARNING) << "Unsupported output type " << out_type << ", use FP32 as default.";
} else {
output_type = iter->second;
}
precision_mode = ModelContext::GetPrecisionMode(context);
op_select_impl_mode = ModelContext::GetOpSelectImplMode(context);
} }
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> AclModelOptions::GenAclOptions() std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> AclModelOptions::GenAclOptions()
@ -69,4 +77,16 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
} }
return {init_options, build_options}; return {init_options, build_options};
} }
} // namespace mindspore::api
std::string AclModelOptions::GenAclOptionsKey() const {
auto [init_options, build_options] = GenAclOptions();
std::string key_str;
for (auto &[key, value] : init_options) {
key_str += key + "^" + value + "^^";
}
for (auto &[key, value] : build_options) {
key_str += key + "^" + value + "^^";
}
return key_str;
}
} // namespace mindspore
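GenAclOptionsKey simply concatenates every init and build option as key^value^^, so two contexts that resolve to the same ACL options map to one cached graph in AclModel::Build. A worked example with hypothetical option names and values:

  // init options  : {"soc_version": "Ascend310"}
  // build options : {"input_format": "NCHW", "input_shape": "x:1,3,224,224"}
  // resulting key : "soc_version^Ascend310^^input_format^NCHW^^input_shape^x:1,3,224,224^^"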
@ -20,12 +20,13 @@
#include <string> #include <string>
#include <map> #include <map>
#include <tuple> #include <tuple>
#include <memory>
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/status.h" #include "include/api/status.h"
#include "include/api/context.h"
namespace mindspore::api { namespace mindspore {
struct AclModelOptions { struct AclModelOptions {
std::string output_node; // todo: at convert.cc::BuildGraph(), no atc options
// build options // build options
std::string insert_op_cfg_path; std::string insert_op_cfg_path;
std::string input_format; std::string input_format;
@ -35,12 +36,13 @@ struct AclModelOptions {
std::string op_select_impl_mode; std::string op_select_impl_mode;
std::string soc_version = "Ascend310"; std::string soc_version = "Ascend310";
explicit AclModelOptions(const std::map<std::string, std::string> &options); explicit AclModelOptions(const std::shared_ptr<Context> &context);
~AclModelOptions() = default; ~AclModelOptions() = default;
// return tuple<init_options, build_options> // return tuple<init_options, build_options>
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const; std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;
std::string GenAclOptionsKey() const;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H #endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H
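A short sketch of how a Context now feeds these options (illustration only; the values are made up, and of the setters only ModelContext::SetInputShape actually appears in this diff — the construction itself mirrors what AclModel::Build does):

  auto ctx = std::make_shared<ModelContext>();
  ModelContext::SetInputShape(ctx, "x:1,3,224,224");  // same setter AclModel::Resize uses
  AclModelOptions options(ctx);            // constructor pulls the fields via the ModelContext getters
  auto key = options.GenAclOptionsKey();   // used as the graph cache key in AclModel::Build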
@ -22,9 +22,8 @@
#include "include/api/serialization.h" #include "include/api/serialization.h"
#include "graph/model.h" #include "graph/model.h"
#include "cxx_api/model/model_converter_utils/multi_process.h" #include "cxx_api/model/model_converter_utils/multi_process.h"
#include "cxx_api/python_utils.h"
namespace mindspore::api { namespace mindspore {
namespace { namespace {
transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) { transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
transform::TensorOrderMap res; transform::TensorOrderMap res;
@ -86,25 +85,25 @@ transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &
para->set_name(name); para->set_name(name);
} }
transform::DfGraphConvertor convertor(anf_graph); transform::DfGraphConvertor converter(anf_graph);
std::string net_id = "0"; std::string net_id = "0";
std::string init_graph = "init_subgraph." + net_id; std::string init_graph = "init_subgraph." + net_id;
std::string checkpoint_name = "save." + net_id; std::string checkpoint_name = "save." + net_id;
convertor.set_training(false); converter.set_training(false);
(void)convertor.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph(); (void)converter.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph();
(void)convertor.GenerateCheckpointGraph(); (void)converter.GenerateCheckpointGraph();
if (convertor.ErrCode() != 0) { if (converter.ErrCode() != 0) {
transform::DfGraphManager::GetInstance().ClearGraph(); transform::DfGraphManager::GetInstance().ClearGraph();
MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); MS_LOG(ERROR) << "Convert df graph failed, err:" << converter.ErrCode();
return nullptr; return nullptr;
} }
(void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), convertor.GetComputeGraph()); (void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), converter.GetComputeGraph());
(void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); (void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, converter.GetInitGraph());
(void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); (void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, converter.GetBroadcastGraph());
transform::Status ret = transform::Status ret =
transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, converter.GetSaveCheckpointGraph());
if (ret == transform::Status::SUCCESS) { if (ret == transform::Status::SUCCESS) {
transform::DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); transform::DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph);
} }
@ -158,7 +157,7 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
auto df_graph = ConvertFuncGraphToAIR(func_graph); auto df_graph = ConvertFuncGraphToAIR(func_graph);
if (df_graph == nullptr) { if (df_graph == nullptr) {
MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed."; MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
return FAILED; return kMCFailed;
} }
ge::Model model; ge::Model model;
ge::Buffer model_data; ge::Buffer model_data;
@ -166,14 +165,14 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
auto ge_ret = model.Save(model_data); auto ge_ret = model.Save(model_data);
if (ge_ret != ge::SUCCESS) { if (ge_ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Save ge model to buffer failed."; MS_LOG(ERROR) << "Save ge model to buffer failed.";
return FAILED; return kMCFailed;
} }
// send original model to child // send original model to child
auto status = multi_process->SendMsg(model_data.data(), model_data.size()); auto status = multi_process->SendMsg(model_data.data(), model_data.size());
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Send original model to child process failed"; MS_LOG_ERROR << "Send original model to child process failed";
return FAILED; return status;
} }
// receive convert model result from child // receive convert model result from child
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * { CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
@ -181,11 +180,11 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData()); return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
}; };
status = multi_process->ReceiveMsg(call); status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Receive result model from child process failed"; MS_LOG_ERROR << "Receive result model from child process failed";
return FAILED; return status;
} }
return SUCCESS; return kSuccess;
}; };
auto child_process = [this](MultiProcess *multi_process) -> Status { auto child_process = [this](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process); MS_EXCEPTION_IF_NULL(multi_process);
@ -196,25 +195,25 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
return reinterpret_cast<uint8_t *>(model.MutableData()); return reinterpret_cast<uint8_t *>(model.MutableData());
}; };
auto status = multi_process->ReceiveMsg(call); auto status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Receive original model from parent process failed"; MS_LOG_ERROR << "Receive original model from parent process failed";
return FAILED; return status;
} }
Buffer model_result = LoadAscendIRInner(model); Buffer model_result = LoadAscendIRInner(model);
if (model_result.DataSize() == 0) { if (model_result.DataSize() == 0) {
MS_LOG_ERROR << "Convert model from MindIR to OM failed"; MS_LOG_ERROR << "Convert model from MindIR to OM failed";
return FAILED; return kMCFailed;
} }
// send result model to parent // send result model to parent
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize()); status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Send result model to parent process failed"; MS_LOG_ERROR << "Send result model to parent process failed";
return FAILED; return status;
} }
return SUCCESS; return kSuccess;
}; };
auto status = multi_process.MainProcess(parent_process, child_process); auto status = multi_process.MainProcess(parent_process, child_process);
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Convert MindIR model to OM model failed"; MS_LOG_ERROR << "Convert MindIR model to OM model failed";
} else { } else {
MS_LOG_INFO << "Convert MindIR model to OM model success"; MS_LOG_INFO << "Convert MindIR model to OM model success";
@ -229,9 +228,9 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
MS_EXCEPTION_IF_NULL(multi_process); MS_EXCEPTION_IF_NULL(multi_process);
// send original model to child // send original model to child
auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize()); auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Send original model to child process failed"; MS_LOG_ERROR << "Send original model to child process failed";
return FAILED; return status;
} }
// receive convert model result from child // receive convert model result from child
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * { CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
@ -239,11 +238,11 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData()); return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
}; };
status = multi_process->ReceiveMsg(call); status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Receive result model from child process failed"; MS_LOG_ERROR << "Receive result model from child process failed";
return FAILED; return status;
} }
return SUCCESS; return kSuccess;
}; };
auto child_process = [this](MultiProcess *multi_process) -> Status { auto child_process = [this](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process); MS_EXCEPTION_IF_NULL(multi_process);
@ -254,25 +253,25 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
return reinterpret_cast<uint8_t *>(model.MutableData()); return reinterpret_cast<uint8_t *>(model.MutableData());
}; };
auto status = multi_process->ReceiveMsg(call); auto status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Receive original model from parent process failed"; MS_LOG_ERROR << "Receive original model from parent process failed";
return FAILED; return status;
} }
Buffer model_result = LoadAscendIRInner(model); Buffer model_result = LoadAscendIRInner(model);
if (model_result.DataSize() == 0) { if (model_result.DataSize() == 0) {
MS_LOG_ERROR << "Convert model from AIR to OM failed"; MS_LOG_ERROR << "Convert model from AIR to OM failed";
return FAILED; return kMCFailed;
} }
// send result model to parent // send result model to parent
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize()); status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Send result model to parent process failed"; MS_LOG_ERROR << "Send result model to parent process failed";
return FAILED; return status;
} }
return SUCCESS; return kSuccess;
}; };
auto status = multi_process.MainProcess(parent_process, child_process); auto status = multi_process.MainProcess(parent_process, child_process);
if (!status.IsSuccess()) { if (status != kSuccess) {
MS_LOG_ERROR << "Convert AIR model to OM model failed"; MS_LOG_ERROR << "Convert AIR model to OM model failed";
} else { } else {
MS_LOG_INFO << "Convert AIR model to OM model success"; MS_LOG_INFO << "Convert AIR model to OM model success";
@ -326,4 +325,4 @@ Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
auto om_data = BuildAirModel(df_graph, init_options, build_options); auto om_data = BuildAirModel(df_graph, init_options, build_options);
return om_data; return om_data;
} }
} // namespace mindspore::api } // namespace mindspore
@ -27,7 +27,7 @@
#include "external/ge/ge_ir_build.h" #include "external/ge/ge_ir_build.h"
#include "cxx_api/model/acl/acl_model_options.h" #include "cxx_api/model/acl/acl_model_options.h"
namespace mindspore::api { namespace mindspore {
class ModelConverter { class ModelConverter {
public: public:
ModelConverter() : options_(nullptr) {} ModelConverter() : options_(nullptr) {}
@ -46,6 +46,5 @@ class ModelConverter {
Buffer LoadMindIRInner(const FuncGraphPtr &func_graph); Buffer LoadMindIRInner(const FuncGraphPtr &func_graph);
Buffer LoadAscendIRInner(const Buffer &model_data); Buffer LoadAscendIRInner(const Buffer &model_data);
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H #endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H
@ -19,49 +19,45 @@
#include "cxx_api/factory.h" #include "cxx_api/factory.h"
#include "utils/utils.h" #include "utils/utils.h"
namespace mindspore::api { namespace mindspore {
Status Model::Build(const std::map<std::string, std::string> &options) { Status Model::Build() {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->Build(options); return impl_->Build();
} }
Status Model::Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) { Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->Train(dataset, outputs); return impl_->Resize(inputs, dims);
} }
Status Model::Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) { Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Eval(dataset, outputs);
}
Status Model::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->Predict(inputs, outputs); return impl_->Predict(inputs, outputs);
} }
Status Model::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> Model::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetInputsInfo(names, shapes, data_types, mem_sizes); return impl_->GetInputs();
} }
Status Model::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> Model::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetOutputsInfo(names, shapes, data_types, mem_sizes); return impl_->GetOutputs();
} }
Model::Model(const GraphCell &graph_cell) Model::Model(const GraphCell &graph_cell, const std::shared_ptr<Context> &model_context)
: impl_(Factory<ModelImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) { : impl_(Factory<ModelImpl>::Instance().Create(mindspore::GlobalContext::GetGlobalDeviceTarget())) {
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(EXCEPTION) << "Create session type " << Context::Instance().GetDeviceTarget() << " failed"; MS_LOG(EXCEPTION) << "Create session type " << mindspore::GlobalContext::GetGlobalDeviceTarget() << " failed";
} }
MS_EXCEPTION_IF_NULL(graph_cell.GetGraph()); MS_EXCEPTION_IF_NULL(graph_cell.GetGraph());
impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph())); impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
impl_->SetContext(model_context);
} }
Model::Model(const std::vector<Output> &network) { MS_LOG(EXCEPTION) << "Unsupported feature."; } Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context) {
MS_LOG(EXCEPTION) << "Unsupported feature.";
}
Model::~Model() {} Model::~Model() {}
@ -69,4 +65,4 @@ bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type); return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
} }
} // namespace mindspore::api } // namespace mindspore
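Taken together, the refactor collapses a typical inference call into the sequence below. This is a hedged sketch rather than code from this commit: the model file name and the error handling are assumptions, and the context argument may also be left null.

  Graph graph = Serialization::LoadModel("net.mindir", ModelType::kMindIR);  // hypothetical MindIR file
  GraphCell graph_cell(std::make_shared<Graph>(graph));
  Model model(graph_cell, std::make_shared<ModelContext>());
  if (model.Build() != kSuccess) {
    MS_LOG(ERROR) << "Build failed.";
  }
  std::vector<MSTensor> inputs = model.GetInputs();  // fill each tensor with real input data before running
  std::vector<MSTensor> outputs;
  if (model.Predict(inputs, &outputs) != kSuccess) {
    MS_LOG(ERROR) << "Predict failed.";
  }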
@ -24,7 +24,6 @@
#include "cxx_api/model/model_converter_utils/shared_memory.h" #include "cxx_api/model/model_converter_utils/shared_memory.h"
namespace mindspore { namespace mindspore {
namespace api {
namespace { namespace {
uint64_t kSharedMemorySize = 100ull << 20; // 100 MB uint64_t kSharedMemorySize = 100ull << 20; // 100 MB
} }
@ -40,7 +39,7 @@ Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall
memory_size_ = kSharedMemorySize; // 100 MB memory_size_ = kSharedMemorySize; // 100 MB
SharedMemory shared_memory; SharedMemory shared_memory;
ret = shared_memory.Create(memory_size_); ret = shared_memory.Create(memory_size_);
if (!ret.IsSuccess()) { if (ret != kSuccess) {
MS_LOG_ERROR << "Create shared memory failed"; MS_LOG_ERROR << "Create shared memory failed";
return ret; return ret;
} }
@ -48,10 +47,10 @@ Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall
if (pid < 0) { if (pid < 0) {
shared_memory.Destroy(); shared_memory.Destroy();
MS_LOG_ERROR << "Fork process to convert model failed"; MS_LOG_ERROR << "Fork process to convert model failed";
return FAILED; return kMEFailed;
} }
ret = shared_memory.Attach(); ret = shared_memory.Attach();
if (!ret.IsSuccess()) { if (ret != kSuccess) {
MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid; MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid;
return ret; return ret;
} }
@ -87,12 +86,12 @@ Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) {
Status ret; Status ret;
try { try {
ret = parent_process(this); ret = parent_process(this);
if (!ret.IsSuccess()) { if (ret != kSuccess) {
MS_LOG_ERROR << "Parent process process failed"; MS_LOG_ERROR << "Parent process process failed";
} }
} catch (const std::runtime_error &ex) { } catch (const std::runtime_error &ex) {
MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what(); MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what();
ret = FAILED; ret = kMEFailed;
} }
stopped_ = true; stopped_ = true;
send_msg_->stop = true; send_msg_->stop = true;
@ -108,7 +107,7 @@ void MultiProcess::ChildProcess(ProcessFuncCall child_process) {
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this); std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
try { try {
auto ret = child_process(this); auto ret = child_process(this);
if (!ret.IsSuccess()) { if (ret != kSuccess) {
MS_LOG_ERROR << "Child process process failed"; MS_LOG_ERROR << "Child process process failed";
} }
} catch (const std::runtime_error &ex) { } catch (const std::runtime_error &ex) {
@ -138,14 +137,14 @@ Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
} }
if (peer_stopped_) { if (peer_stopped_) {
if (!send_msg_->read_finish_flag) { if (!send_msg_->read_finish_flag) {
return FAILED; return kMEFailed;
} }
break; break;
} }
MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len; MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
} }
MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len; MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len;
return SUCCESS; return kSuccess;
} }
Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) { Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
@ -158,7 +157,7 @@ Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
usleep(1000); // 1ms usleep(1000); // 1ms
} }
if (peer_stopped_) { if (peer_stopped_) {
return FAILED; return kMEFailed;
} }
if (msg_buffer == nullptr) { if (msg_buffer == nullptr) {
msg_len = receive_msg_->msg_total_len; msg_len = receive_msg_->msg_total_len;
@ -170,7 +169,7 @@ Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
receive_msg_->read_finish_flag = true; receive_msg_->read_finish_flag = true;
MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl; MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
} while (msg_len > cur_offset); } while (msg_len > cur_offset);
return SUCCESS; return kSuccess;
} }
void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); } void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
@ -200,6 +199,4 @@ void MultiProcess::HeartbeatThreadFuncInner() {
usleep(100000); // sleep 100 ms usleep(100000); // sleep 100 ms
} }
} }
} // namespace api
} // namespace mindspore } // namespace mindspore
@ -21,7 +21,6 @@
#include "include/api/status.h" #include "include/api/status.h"
namespace mindspore { namespace mindspore {
namespace api {
struct MessageFlag { struct MessageFlag {
uint64_t heartbeat = 0; uint64_t heartbeat = 0;
uint64_t stop = false; uint64_t stop = false;
@ -60,7 +59,5 @@ class MultiProcess {
Status ParentProcess(ProcessFuncCall parent_process); Status ParentProcess(ProcessFuncCall parent_process);
void ChildProcess(ProcessFuncCall child_process); void ChildProcess(ProcessFuncCall child_process);
}; };
} // namespace api
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H #endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
@ -20,26 +20,25 @@
#include "mindspore/core/utils/log_adapter.h" #include "mindspore/core/utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace api {
Status SharedMemory::Create(uint64_t memory_size) { Status SharedMemory::Create(uint64_t memory_size) {
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode); shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode);
if (shm_id_ == -1) { if (shm_id_ == -1) {
MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno); MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno);
return FAILED; return kMCFailed;
} }
MS_LOG_INFO << "shmget success, shm id " << shm_id_; MS_LOG_INFO << "shmget success, shm id " << shm_id_;
return SUCCESS; return kSuccess;
} }
Status SharedMemory::Attach() { Status SharedMemory::Attach() {
void *shmat_addr = shmat(shm_id_, nullptr, 0); void *shmat_addr = shmat(shm_id_, nullptr, 0);
if (shmat_addr == reinterpret_cast<void *>(-1)) { if (shmat_addr == reinterpret_cast<void *>(-1)) {
MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno); MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno);
return FAILED; return kMCFailed;
} }
shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr); shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr);
return SUCCESS; return kSuccess;
} }
void SharedMemory::Detach() { void SharedMemory::Detach() {
@ -63,5 +62,4 @@ void SharedMemory::Destroy() {
MS_LOG_ERROR << errMsg; MS_LOG_ERROR << errMsg;
} }
} }
} // namespace api
} // namespace mindspore } // namespace mindspore
@ -20,7 +20,6 @@
#include "include/api/status.h" #include "include/api/status.h"
namespace mindspore { namespace mindspore {
namespace api {
class SharedMemory { class SharedMemory {
public: public:
Status Create(uint64_t memory_size); Status Create(uint64_t memory_size);
@ -33,7 +32,5 @@ class SharedMemory {
int shm_id_ = -1; int shm_id_ = -1;
uint8_t *shmat_addr_ = nullptr; uint8_t *shmat_addr_ = nullptr;
}; };
} // namespace api
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H #endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
@ -21,28 +21,26 @@
 #include <vector>
 #include <memory>
 #include <utility>
+#include "include/api/context.h"
 #include "include/api/model.h"
 #include "include/api/graph.h"
 #include "cxx_api/graph/graph_data.h"
 #include "utils/utils.h"
 #include "ir/func_graph.h"
-namespace mindspore::api {
+namespace mindspore {
 class ModelImpl {
  public:
   ModelImpl() = default;
   virtual ~ModelImpl() = default;
-  virtual Status Build(const std::map<std::string, std::string> &options) = 0;
-  virtual Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
-  virtual Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
-  virtual Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
-  virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                               std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
-  virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
+  virtual Status Build() = 0;
+  virtual Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) = 0;
+  virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
+  virtual std::vector<MSTensor> GetInputs() = 0;
+  virtual std::vector<MSTensor> GetOutputs() = 0;
  protected:
   Status Load(const std::shared_ptr<GraphCell> &graph_cell) {
@ -61,11 +59,16 @@ class ModelImpl {
   }
   std::shared_ptr<Graph> graph_;
+  std::shared_ptr<Context> model_context_;
  private:
   friend class Model;
   void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
+  void SetContext(const std::shared_ptr<Context> &model_context) {
+    if (model_context != nullptr) {
+      model_context_ = std::make_shared<Context>(*model_context);
+    }
+  }
 };
-}  // namespace mindspore::api
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H
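A skeletal backend against the slimmed-down interface now only has to provide five entry points. The class below is a sketch only (hypothetical name, not a real backend, and registration via API_FACTORY_REG is omitted):

  class DummyModel : public ModelImpl {
   public:
    Status Build() override { return kSuccess; }
    Status Resize(const std::vector<MSTensor> &, const std::vector<std::vector<int64_t>> &) override { return kSuccess; }
    Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override {
      *outputs = inputs;  // echo the inputs back, purely for illustration
      return kSuccess;
    }
    std::vector<MSTensor> GetInputs() override { return {}; }
    std::vector<MSTensor> GetOutputs() override { return {}; }
  };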
@ -16,18 +16,78 @@
#include "cxx_api/model/ms/ms_model.h" #include "cxx_api/model/ms/ms_model.h"
#include <memory> #include <memory>
#include "include/api/context.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "cxx_api/factory.h" #include "cxx_api/factory.h"
namespace mindspore { namespace mindspore {
namespace api {
API_FACTORY_REG(ModelImpl, Ascend910, MsModel); API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
API_FACTORY_REG(ModelImpl, GPU, MsModel); API_FACTORY_REG(ModelImpl, GPU, MsModel);
Status MsModel::Build(const std::map<std::string, std::string> &) { static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
std::string shape_key;
for (size_t i = 0; i < dims.size(); ++i) {
shape_key += std::to_string(i) + ":";
for (size_t j = 0; j < dims[i].size(); ++j) {
shape_key += std::to_string(dims[i][j]);
if (j + 1 < dims[i].size()) {
shape_key += ",";
}
}
if (i + 1 < dims.size()) {
shape_key += ";";
}
}
return shape_key;
}
std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims) {
std::string shape_key = GenerateShapeKey(dims);
if (auto iter = dynamic_size_graph_map_.find(shape_key); iter != dynamic_size_graph_map_.end()) {
MS_LOG(INFO) << "This options has been built, read cache.";
return iter->second;
}
auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &inputs = func_graph->parameters();
if (dims.size() != inputs.size()) {
MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match model inputs size " << inputs.size();
return nullptr;
}
for (size_t i = 0; i < dims.size(); ++i) {
const auto &param = inputs[i];
auto shape_ptr = std::dynamic_pointer_cast<abstract::Shape>(param->Shape());
if (shape_ptr == nullptr) {
MS_LOG(ERROR) << "Inputs " << i << " is not supported to resize, debug string: " << param->DebugString();
return nullptr;
}
shape_ptr->shape() = dims[i];
}
auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(func_graph, ModelType::kMindIR));
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
auto ret = ModelImpl::Load(graph_cell);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";
return nullptr;
}
dynamic_size_graph_map_[shape_key] = graph_cell;
return graph_cell;
}
Status MsModel::Build() {
MS_LOG(INFO) << "Start build model."; MS_LOG(INFO) << "Start build model.";
MS_EXCEPTION_IF_NULL(graph_); MS_EXCEPTION_IF_NULL(graph_);
if (graph_cell_ != nullptr) {
MS_LOG(INFO) << "This model has been built, skip.";
return kSuccess;
}
auto func_graph = ModelImpl::GetFuncGraph(); auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -36,7 +96,7 @@ Status MsModel::Build(const std::map<std::string, std::string> &) {
auto graph_cell = std::make_shared<GraphCell>(graph); auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell); MS_EXCEPTION_IF_NULL(graph_cell);
auto ret = ModelImpl::Load(graph_cell); auto ret = ModelImpl::Load(graph_cell);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed."; MS_LOG(ERROR) << "Load failed.";
return ret; return ret;
} }
@ -44,55 +104,66 @@ Status MsModel::Build(const std::map<std::string, std::string> &) {
// save result // save result
graph_cell_ = graph_cell; graph_cell_ = graph_cell;
MS_LOG(INFO) << "Build model success."; MS_LOG(INFO) << "Build model success.";
return SUCCESS; return kSuccess;
} }
Status MsModel::Train(const DataSet &, std::map<std::string, Buffer> *) { Status MsModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(INFO) << "Start to resize model";
return FAILED; auto origin_inputs = GetInputs();
if (inputs.size() != origin_inputs.size()) {
MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size();
return kMCInvalidInput;
}
if (inputs.size() != dims.size()) {
MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size();
return kMCInvalidInput;
}
auto graph_cell = GenerateGraphCell(dims);
if (graph_cell == nullptr) {
MS_LOG(ERROR) << "GenerateGraphCell failed.";
return kMCFailed;
}
MS_LOG(INFO) << "Resize model success.";
graph_cell_ = std::move(graph_cell);
return kSuccess;
} }
Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) { Status MsModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
}
Status MsModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
if (graph_ == nullptr) { if (graph_ == nullptr) {
MS_LOG(ERROR) << "Invalid data, graph_ is null."; MS_LOG(ERROR) << "Invalid data, graph_ is null.";
return FAILED; return kMCFailed;
} }
if (graph_cell_ == nullptr) { if (graph_cell_ == nullptr) {
MS_LOG(INFO) << "Model has not been built, it will be built with default options"; MS_LOG(INFO) << "Model has not been built, it will be built with default options";
Status ret = Build({}); Status ret = Build();
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed."; MS_LOG(ERROR) << "Build model failed.";
return FAILED; return ret;
} }
} }
MS_EXCEPTION_IF_NULL(graph_cell_); MS_EXCEPTION_IF_NULL(graph_cell_);
Status ret = graph_cell_->Run(inputs, outputs); Status ret = graph_cell_->Run(inputs, outputs);
if (ret != SUCCESS) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Run graph failed."; MS_LOG(ERROR) << "Run graph failed.";
return FAILED; return ret;
} }
return SUCCESS; return kSuccess;
} }
Status MsModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> MsModel::GetInputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(graph_cell_); MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes); return graph_cell_->GetInputs();
} }
Status MsModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> MsModel::GetOutputs() {
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(graph_cell_); MS_EXCEPTION_IF_NULL(graph_cell_);
return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes); return graph_cell_->GetOutputs();
} }
} // namespace api
} // namespace mindspore } // namespace mindspore
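GenerateShapeKey encodes the requested shapes as index:dim,dim,...;index:..., and that string keys the cache of rebuilt graph cells used by Resize. A worked example with hypothetical shapes:

  // dims = {{1, 3, 224, 224}, {1, 10}}
  // GenerateShapeKey(dims) == "0:1,3,224,224;1:1,10"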
@ -33,26 +33,24 @@
#endif #endif
namespace mindspore { namespace mindspore {
namespace api {
class MsModel : public ModelImpl { class MsModel : public ModelImpl {
public: public:
MsModel() {} MsModel() {}
~MsModel() = default; ~MsModel() = default;
Status Build(const std::map<std::string, std::string> &options_map) override; Status Build() override;
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, std::vector<MSTensor> GetInputs() override;
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override; std::vector<MSTensor> GetOutputs() override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
private: private:
std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims);
std::shared_ptr<GraphCell> graph_cell_; std::shared_ptr<GraphCell> graph_cell_;
std::map<std::string, std::shared_ptr<GraphCell>> dynamic_size_graph_map_;
}; };
} // namespace api
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
@ -15,7 +15,7 @@
*/ */
#include "include/api/ops/ops.h" #include "include/api/ops/ops.h"
namespace mindspore::api { namespace mindspore {
Conv2D::Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode, Conv2D::Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, int group) const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, int group)
: OpCell("Conv2D"), : OpCell("Conv2D"),
@ -35,4 +35,4 @@ Output Conv2D::operator()(const Input &input1, const Input &input2) const {
std::vector<Output> Conv2D::Construct(const std::vector<Input> &inputs) { std::vector<Output> Conv2D::Construct(const std::vector<Input> &inputs) {
return {Output(shared_from_this(), inputs, 1)}; return {Output(shared_from_this(), inputs, 1)};
} }
} // namespace mindspore::api } // namespace mindspore
@ -29,7 +29,7 @@ namespace py = pybind11;
static std::mutex init_mutex; static std::mutex init_mutex;
static bool Initialized = false; static bool Initialized = false;
namespace mindspore::api { namespace mindspore {
static void RegAllOpFromPython() { static void RegAllOpFromPython() {
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
Py_Initialize(); Py_Initialize();
@ -143,4 +143,4 @@ PythonEnvGuard::~PythonEnvGuard() {
FinalizePython(); FinalizePython();
} }
} }
} // namespace mindspore::api } // namespace mindspore

View File

@ -16,7 +16,7 @@
#ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H #ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
#define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H #define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
namespace mindspore::api { namespace mindspore {
void RegAllOp(); void RegAllOp();
bool PythonIsInited(); bool PythonIsInited();
void InitPython(); void InitPython();
@ -30,5 +30,5 @@ class PythonEnvGuard {
private: private:
bool origin_init_status_; bool origin_init_status_;
}; };
} // namespace mindspore::api } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H #endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H

View File

@ -19,7 +19,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "mindspore/core/load_mindir/load_model.h" #include "mindspore/core/load_mindir/load_model.h"
namespace mindspore::api { namespace mindspore {
static Buffer ReadFile(const std::string &file) { static Buffer ReadFile(const std::string &file) {
Buffer buffer; Buffer buffer;
if (file.empty()) { if (file.empty()) {
@ -68,6 +68,22 @@ static Buffer ReadFile(const std::string &file) {
return buffer; return buffer;
} }
Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelType model_type) {
if (model_type == kMindIR) {
FuncGraphPtr anf_graph = nullptr;
try {
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
} catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Load MindIR failed.";
}
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
} else if (model_type == kOM) {
return Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
}
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
}
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Buffer data = ReadFile(file); Buffer data = ReadFile(file);
if (data.Data() == nullptr) { if (data.Data() == nullptr) {
@ -77,7 +93,7 @@ Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
FuncGraphPtr anf_graph = nullptr; FuncGraphPtr anf_graph = nullptr;
try { try {
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize()); anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize());
} catch (std::exception &e) { } catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Load MindIR failed."; MS_LOG(EXCEPTION) << "Load MindIR failed.";
} }
@ -90,21 +106,21 @@ Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) { Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(ERROR) << "Unsupported feature.";
return FAILED; return kMEFailed;
} }
Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) { Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(ERROR) << "Unsupported feature.";
return FAILED; return kMEFailed;
} }
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) { Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(ERROR) << "Unsupported feature.";
return FAILED; return kMEFailed;
} }
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) { Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(ERROR) << "Unsupported feature.";
return FAILED; return kMEFailed;
} }
} // namespace mindspore::api } // namespace mindspore
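A short sketch of the two LoadModel overloads kept and added in this file: one reads a MindIR file from disk, the other parses a buffer already in memory. The include path and the file name "model.mindir" are placeholders, not taken from this patch.

#include <cstddef>
#include <string>
#include "include/api/serialization.h"  // assumed header exposing Serialization and Graph

// Load the same MindIR model from a path and from a raw buffer.
void LoadBothWays(const void *model_buf, size_t buf_size) {
  mindspore::Graph from_file = mindspore::Serialization::LoadModel("model.mindir", mindspore::kMindIR);
  mindspore::Graph from_buffer = mindspore::Serialization::LoadModel(model_buf, buf_size, mindspore::kMindIR);
  (void)from_file;
  (void)from_buffer;
}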

View File

@ -17,17 +17,20 @@
#include <numeric> #include <numeric>
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "mindspore/core/ir/api_tensor_impl.h"
namespace mindspore::api { namespace mindspore {
const char *kDeviceTypeAscend310 = "Ascend310"; class Buffer::Impl {
const char *kDeviceTypeAscend910 = "Ascend910";
const char *kDeviceTypeGpu = "GPU";
class DataImpl {
public: public:
DataImpl() : data_() {} Impl() : data_() {}
~DataImpl() = default; ~Impl() = default;
DataImpl(const void *data, size_t data_len) { SetData(data, data_len); } Impl(const void *data, size_t data_len) {
if (data != nullptr) {
(void)SetData(data, data_len);
} else {
ResizeData(data_len);
}
}
const void *Data() const { return data_.data(); } const void *Data() const { return data_.data(); }
void *MutableData() { return data_.data(); } void *MutableData() { return data_.data(); }
@ -66,132 +69,162 @@ class DataImpl {
std::vector<uint8_t> data_; std::vector<uint8_t> data_;
}; };
class Buffer::Impl : public DataImpl { class TensorDefaultImpl : public MSTensor::Impl {
public: public:
Impl() : DataImpl() {} TensorDefaultImpl() : buffer_(), name_(), type_(DataType::kTypeUnknown), shape_() {}
~Impl() = default; ~TensorDefaultImpl() override = default;
Impl(const void *data, size_t data_len) : DataImpl(data, data_len) {} TensorDefaultImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
}; size_t data_len)
: buffer_(data, data_len), name_(name), type_(type), shape_(shape) {}
class Tensor::Impl : public DataImpl { const std::string &Name() const override { return name_; }
public: enum DataType DataType() const override { return type_; }
Impl() : DataImpl(), name_(), type_(DataType::kMsUnknown), shape_() {} const std::vector<int64_t> &Shape() const override { return shape_; }
~Impl() = default;
Impl(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: DataImpl(data, data_len), name_(name), type_(type), shape_(shape) {}
const std::string &Name() const { return name_; } std::shared_ptr<const void> Data() const override {
void SetName(const std::string &name) { name_ = name; } return std::shared_ptr<const void>(buffer_.Data(), [](const void *) {});
api::DataType DataType() const { return type_; }
void SetDataType(api::DataType type) { type_ = type; }
void SetShape(const std::vector<int64_t> &shape) { shape_ = shape; }
const std::vector<int64_t> &Shape() const { return shape_; }
int64_t ElementNum() const {
std::vector<int64_t> shapex = Shape();
return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
} }
static int GetTypeSize(api::DataType type) { void *MutableData() override { return buffer_.MutableData(); }
static const std::map<api::DataType, size_t> type_size_map = { size_t DataSize() const override { return buffer_.DataSize(); }
{kMsBool, sizeof(bool)}, {kMsFloat64, sizeof(double)}, {kMsInt8, sizeof(int8_t)},
{kMsUint8, sizeof(uint8_t)}, {kMsInt16, sizeof(int16_t)}, {kMsUint16, sizeof(uint16_t)},
{kMsInt32, sizeof(int32_t)}, {kMsUint32, sizeof(uint32_t)}, {kMsInt64, sizeof(int64_t)},
{kMsUint64, sizeof(uint64_t)}, {kMsFloat16, sizeof(uint16_t)}, {kMsFloat32, sizeof(float)},
};
auto it = type_size_map.find(type);
if (it != type_size_map.end()) {
return it->second;
}
MS_LOG(WARNING) << "Cannot find data type " << type; bool IsDevice() const override { return false; }
return 0;
std::shared_ptr<Impl> Clone() const override {
return std::make_shared<TensorDefaultImpl>(name_, type_, shape_, buffer_.Data(), buffer_.DataSize());
} }
private: private:
Buffer buffer_;
std::string name_; std::string name_;
api::DataType type_; enum DataType type_;
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
}; };
Tensor::Tensor() : impl_(std::make_shared<Impl>()) {} class TensorReferenceImpl : public MSTensor::Impl {
Tensor::Tensor(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data, public:
size_t data_len) TensorReferenceImpl() : data_(nullptr), data_size_(0), name_(), type_(DataType::kTypeUnknown), shape_() {}
: impl_(std::make_shared<Impl>(name, type, shape, data, data_len)) {} ~TensorReferenceImpl() override = default;
Tensor::~Tensor() = default; TensorReferenceImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: data_(data), data_size_(data_len), name_(name), type_(type), shape_(shape) {}
Tensor Tensor::Clone() const { const std::string &Name() const override { return name_; }
enum DataType DataType() const override { return type_; }
const std::vector<int64_t> &Shape() const override { return shape_; }
std::shared_ptr<const void> Data() const override {
return std::shared_ptr<const void>(data_, [](const void *) {});
}
void *MutableData() override { return const_cast<void *>(data_); }
size_t DataSize() const override { return data_size_; }
bool IsDevice() const override { return false; }
std::shared_ptr<Impl> Clone() const override {
return std::make_shared<TensorReferenceImpl>(name_, type_, shape_, data_, data_size_);
}
protected:
const void *data_;
size_t data_size_;
std::string name_;
enum DataType type_;
std::vector<int64_t> shape_;
};
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
return MSTensor(nullptr);
} catch (...) {
MS_LOG(ERROR) << "Unknown error occurred.";
return MSTensor(nullptr);
}
}
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
return MSTensor(nullptr);
} catch (...) {
MS_LOG(ERROR) << "Unknown error occurred.";
return MSTensor(nullptr);
}
}
MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); }
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default;
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
MSTensor MSTensor::Clone() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
Tensor ret; MSTensor ret;
ret.impl_ = std::make_shared<Impl>(*impl_); ret.impl_ = impl_->Clone();
return ret; return ret;
} }
const std::string &Tensor::Name() const { const std::string &MSTensor::Name() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->Name(); return impl_->Name();
} }
void Tensor::SetName(const std::string &name) { enum DataType MSTensor::DataType() const {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetName(name);
}
DataType Tensor::DataType() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->DataType(); return impl_->DataType();
} }
void Tensor::SetDataType(api::DataType type) { const std::vector<int64_t> &MSTensor::Shape() const {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetDataType(type);
}
const std::vector<int64_t> &Tensor::Shape() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->Shape(); return impl_->Shape();
} }
void Tensor::SetShape(const std::vector<int64_t> &shape) { int64_t MSTensor::ElementNum() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
impl_->SetShape(shape); const auto &shape = impl_->Shape();
if (shape.empty()) {
// element number of scalar is 1
return 1;
}
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
} }
const void *Tensor::Data() const { std::shared_ptr<const void> MSTensor::Data() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->Data(); return impl_->Data();
} }
void *Tensor::MutableData() { void *MSTensor::MutableData() {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->MutableData(); return impl_->MutableData();
} }
size_t Tensor::DataSize() const { size_t MSTensor::DataSize() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->DataSize(); return impl_->DataSize();
} }
bool Tensor::ResizeData(size_t data_len) { bool MSTensor::IsDevice() const {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->ResizeData(data_len); return impl_->IsDevice();
} }
bool Tensor::SetData(const void *data, size_t data_len) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->SetData(data, data_len);
}
int64_t Tensor::ElementNum() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->ElementNum();
}
int Tensor::GetTypeSize(api::DataType type) { return Impl::GetTypeSize(type); }
Buffer::Buffer() : impl_(std::make_shared<Impl>()) {} Buffer::Buffer() : impl_(std::make_shared<Impl>()) {}
Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {} Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {}
Buffer::~Buffer() = default; Buffer::~Buffer() = default;
@ -227,4 +260,4 @@ bool Buffer::SetData(const void *data, size_t data_len) {
MS_EXCEPTION_IF_NULL(impl_); MS_EXCEPTION_IF_NULL(impl_);
return impl_->SetData(data, data_len); return impl_->SetData(data, data_len);
} }
} // namespace mindspore::api } // namespace mindspore
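A quick illustration of the ownership difference between the two factories defined above: CreateTensor copies the bytes into a TensorDefaultImpl, while CreateRefTensor only records the caller's pointer in a TensorReferenceImpl, so the source buffer must outlive the tensor. This is a sketch; the DataType value kNumberTypeFloat32 and the include path are assumptions (only kTypeUnknown appears in this hunk).

#include <cstddef>
#include <string>
#include <vector>
#include "include/api/types.h"  // assumed location of MSTensor

void CreateTensors() {
  std::vector<float> host(2 * 3, 1.0f);
  std::vector<int64_t> shape = {2, 3};
  size_t bytes = host.size() * sizeof(float);

  // Owning tensor: keeps its own copy of the data.
  mindspore::MSTensor owned = mindspore::MSTensor::CreateTensor(
      "owned", mindspore::DataType::kNumberTypeFloat32, shape, host.data(), bytes);

  // Reference tensor: zero copy, valid only while `host` stays alive.
  mindspore::MSTensor ref = mindspore::MSTensor::CreateRefTensor(
      "ref", mindspore::DataType::kNumberTypeFloat32, shape, host.data(), bytes);

  // Clone() dispatches to Impl::Clone(); cloning a reference tensor therefore
  // still points at the same caller-owned buffer.
  mindspore::MSTensor copy = ref.Clone();
  (void)owned;
  (void)copy;
}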

View File

@ -284,14 +284,7 @@ else()
endif() endif()
add_dependencies(_c_dataengine mindspore_shared_lib) add_dependencies(_c_dataengine mindspore_shared_lib)
if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
set(MINDSPORE_LINK_OBJECT ${CMAKE_BINARY_DIR}/mindspore/ccsrc/cxx_api/CMakeFiles/mindspore_shared_lib.dir/objects.a)
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib ${MINDSPORE_LINK_OBJECT})
else()
if(ENABLE_ACL)
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
endif()
endif()
if(USE_GLOG) if(USE_GLOG)
target_link_libraries(_c_dataengine PRIVATE mindspore::glog) target_link_libraries(_c_dataengine PRIVATE mindspore::glog)

View File

@ -26,28 +26,13 @@ if(ENABLE_PYTHON)
target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS}) target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
endif() endif()
add_library(cpp-API OBJECT
if(ENABLE_ACL) config.cc
add_library(cpp-API OBJECT datasets.cc
config.cc execute.cc
datasets.cc iterator.cc
execute.cc transforms.cc
iterator.cc samplers.cc
minddata_eager.cc text.cc
transforms.cc vision.cc
samplers.cc )
text.cc
vision.cc
)
else()
add_library(cpp-API OBJECT
config.cc
datasets.cc
execute.cc
iterator.cc
transforms.cc
samplers.cc
text.cc
vision.cc
)
endif()

View File

@ -1,142 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/include/de_tensor.h"
#include "minddata/dataset/include/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "mindspore/lite/include/ms_tensor.h"
#include "utils/hashing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#else
#include "mindspore/lite/src/common/log_adapter.h"
#endif
namespace mindspore {
namespace tensor {
MSTensor *DETensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
return new DETensor(data_type, shape);
}
MSTensor *DETensor::CreateTensor(const std::string &path) {
std::shared_ptr<dataset::Tensor> t;
(void)dataset::Tensor::CreateFromFile(path, &t);
return new DETensor(std::move(t));
}
MSTensor *DETensor::CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data) {
std::shared_ptr<dataset::Tensor> t;
// prepare shape info
std::vector<dataset::dsize_t> t_shape;
std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
(void)dataset::Tensor::CreateFromMemory(dataset::TensorShape(t_shape), dataset::MSTypeToDEType(data_type),
static_cast<uint8_t *>(data), &t);
return new DETensor(std::move(t));
}
DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) {
std::vector<dataset::dsize_t> t_shape;
t_shape.reserve(shape.size());
std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
dataset::Tensor::CreateEmpty(dataset::TensorShape(t_shape), dataset::MSTypeToDEType(data_type), &this->tensor_impl_);
}
DETensor::DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
MSTensor *DETensor::ConvertToLiteTensor() {
// static MSTensor::CreateTensor is only for the LiteTensor
MSTensor *tensor = CreateTensor(this->data_type(), this->shape());
MS_ASSERT(tensor->Size() == this->Size());
memcpy_s(tensor->MutableData(), tensor->Size(), this->MutableData(), this->Size());
return tensor;
}
std::shared_ptr<dataset::Tensor> DETensor::tensor() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_;
}
TypeId DETensor::data_type() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return dataset::DETypeToMSType(this->tensor_impl_->type());
}
TypeId DETensor::set_data_type(TypeId data_type) {
MS_ASSERT(this->tensor_impl_ != nullptr);
if (data_type != this->data_type()) {
std::shared_ptr<dataset::Tensor> temp;
dataset::Tensor::CreateFromMemory(this->tensor_impl_->shape(), dataset::MSTypeToDEType(data_type),
this->tensor_impl_->GetBuffer(), &temp);
this->tensor_impl_ = temp;
}
return data_type;
}
std::vector<int> DETensor::shape() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
std::vector<dataset::dsize_t> t_shape = this->tensor_impl_->shape().AsVector();
std::vector<int> shape;
shape.reserve(t_shape.size());
std::transform(t_shape.begin(), t_shape.end(), std::back_inserter(shape),
[](dataset::dsize_t s) -> int { return static_cast<int>(s); });
return shape;
}
size_t DETensor::set_shape(const std::vector<int> &shape) {
MS_ASSERT(this->tensor_impl_ != nullptr);
std::vector<dataset::dsize_t> t_shape;
t_shape.reserve(shape.size());
std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
[](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
dataset::Status rc = this->tensor_impl_->Reshape(dataset::TensorShape(t_shape));
return shape.size();
}
int DETensor::DimensionSize(size_t index) const {
MS_ASSERT(this->tensor_impl_ != nullptr);
int dim_size = -1;
auto shape = this->shape();
if (index < shape.size()) {
dim_size = shape[index];
} else {
MS_LOG(ERROR) << "Dimension index is wrong: " << index;
}
return dim_size;
}
int DETensor::ElementsNum() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_->Size();
}
size_t DETensor::Size() const {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_->SizeInBytes();
}
void *DETensor::MutableData() {
MS_ASSERT(this->tensor_impl_ != nullptr);
return this->tensor_impl_->GetMutableBuffer();
}
} // namespace tensor
} // namespace mindspore

View File

@ -14,12 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/core/tensor_row.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/include/de_tensor.h"
#endif
#include "minddata/dataset/include/execute.h" #include "minddata/dataset/include/execute.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/include/tensor.h" #include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/kernels/tensor_op.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -30,78 +29,85 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {} Execute::Execute(std::shared_ptr<TensorOperation> op) { ops_.emplace_back(std::move(op)); }
/// \brief Destructor Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops) : ops_(std::move(ops)) {}
Execute::~Execute() = default;
#ifdef ENABLE_ANDROID Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) { // Validate input tensor
// Build the op CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data");
if (op_ == nullptr) { CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");
MS_LOG(ERROR) << "Input TensorOperation is not valid";
return nullptr; // Validate and build runtime ops
std::vector<std::shared_ptr<TensorOp>> transforms;
for (int32_t i = 0; i < ops_.size(); i++) {
CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
transforms.emplace_back(ops_[i]->Build());
} }
std::shared_ptr<Tensor> de_input = std::dynamic_pointer_cast<tensor::DETensor>(input)->tensor(); // Convert mindspore::Tensor to dataset::Tensor
if (de_input == nullptr) { std::shared_ptr<dataset::Tensor> de_tensor;
MS_LOG(ERROR) << "Input Tensor is not valid"; Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(input.Shape()),
return nullptr; MSTypeToDEType(static_cast<TypeId>(input.DataType())),
} (const uchar *)(input.Data().get()), input.DataSize(), &de_tensor);
std::shared_ptr<TensorOp> transform = op_->Build(); RETURN_IF_NOT_OK(rc);
std::shared_ptr<Tensor> de_output;
Status rc = transform->Compute(de_input, &de_output);
if (rc.IsError()) { // Apply transforms on tensor
// execution failed for (auto &t : transforms) {
MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString(); std::shared_ptr<dataset::Tensor> de_output;
return nullptr; RETURN_IF_NOT_OK(t->Compute(de_tensor, &de_output));
}
return std::make_shared<tensor::DETensor>(std::move(de_output));
}
#endif
std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) { // For next transform
// Build the op de_tensor = std::move(de_output);
if (op_ == nullptr) {
MS_LOG(ERROR) << "Input TensorOperation is not valid";
return nullptr;
} }
if (input == nullptr) { // Convert dataset::Tensor to mindspore::Tensor
MS_LOG(ERROR) << "Input Tensor is not valid"; CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
return nullptr; *output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
} return Status::OK();
// will add validate params once API is set
std::shared_ptr<TensorOp> transform = op_->Build();
std::shared_ptr<Tensor> de_output;
Status rc = transform->Compute(input, &de_output);
if (rc.IsError()) {
// execution failed
MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
return nullptr;
}
return de_output;
} }
Status Execute::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list, Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
std::vector<std::shared_ptr<Tensor>> *output_tensor_list) { // Validate input tensor
CHECK_FAIL_RETURN_UNEXPECTED(op_ != nullptr, "Input TensorOperation is not valid");
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid"); CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid");
for (auto &tensor : input_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data");
}
CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");
TensorRow input, output; // Validate and build runtime ops
std::copy(input_tensor_list.begin(), input_tensor_list.end(), std::back_inserter(input)); std::vector<std::shared_ptr<TensorOp>> transforms;
CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "Input Tensor is not valid"); for (int32_t i = 0; i < ops_.size(); i++) {
CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
std::shared_ptr<TensorOp> transform = op_->Build(); RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
Status rc = transform->Compute(input, &output); transforms.emplace_back(ops_[i]->Build());
if (rc.IsError()) {
// execution failed
RETURN_STATUS_UNEXPECTED("Operation execution failed : " + rc.ToString());
} }
std::copy(output.begin(), output.end(), std::back_inserter(*output_tensor_list)); TensorRow de_tensor_list;
for (auto &tensor : input_tensor_list) {
std::shared_ptr<dataset::Tensor> de_tensor;
Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(tensor.Shape()),
MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
(const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_tensor);
RETURN_IF_NOT_OK(rc);
de_tensor_list.emplace_back(std::move(de_tensor));
}
// Apply transforms on tensor
for (auto &t : transforms) {
TensorRow de_output_list;
RETURN_IF_NOT_OK(t->Compute(de_tensor_list, &de_output_list));
// For next transform
de_tensor_list = std::move(de_output_list);
}
for (auto &tensor : de_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(tensor->HasData(), "Apply transform failed, output tensor has no data");
auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
output_tensor_list->emplace_back(ms_tensor);
}
CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid");
return Status::OK(); return Status::OK();
} }
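A usage sketch of the rewritten eager path: Execute now takes one or more TensorOperations and operates on mindspore::MSTensor directly, replacing the removed DETensor/shared_ptr overloads. The Decode/Resize factories and their header are assumptions carried over from the existing vision API; they are not part of this hunk.

#include <memory>
#include <vector>
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/vision.h"  // assumed to still provide Decode()/Resize()

namespace ds = mindspore::dataset;

// Decode a raw image tensor and resize it with the new MSTensor-based overload.
mindspore::Status DecodeAndResize(const mindspore::MSTensor &raw_image, mindspore::MSTensor *out) {
  std::vector<std::shared_ptr<ds::TensorOperation>> ops = {ds::vision::Decode(), ds::vision::Resize({224, 224})};
  ds::Execute transform(ops);
  // Each op is validated, built into a runtime TensorOp, and applied in order.
  return transform(raw_image, out);
}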

View File

@ -1,154 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unistd.h>
#include <unordered_map>
#include "minddata/dataset/include/minddata_eager.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace api {
MindDataEager::MindDataEager(std::vector<std::shared_ptr<dataset::TensorOperation>> ops) : ops_(ops) {}
// Helper function to convert Type from DE to MS
DataType ToMSType(dataset::DataType type) {
switch (dataset::DataType::Type(type)) {
case dataset::DataType::DE_BOOL:
return DataType::kMsBool;
case dataset::DataType::DE_UINT8:
return DataType::kMsUint8;
case dataset::DataType::DE_INT32:
return DataType::kMsInt32;
case dataset::DataType::DE_INT64:
return DataType::kMsInt64;
case dataset::DataType::DE_FLOAT32:
return DataType::kMsFloat32;
default:
return DataType::kMsUnknown;
}
}
// Helper function to convert Type from MS to DE
dataset::DataType ToDEType(DataType type) {
switch (type) {
case DataType::kMsBool:
return dataset::DataType(dataset::DataType::DE_BOOL);
case DataType::kMsUint8:
return dataset::DataType(dataset::DataType::DE_UINT8);
case DataType::kMsInt32:
return dataset::DataType(dataset::DataType::DE_INT32);
case DataType::kMsInt64:
return dataset::DataType(dataset::DataType::DE_INT64);
case DataType::kMsFloat32:
return dataset::DataType(dataset::DataType::DE_FLOAT32);
default:
return dataset::DataType(dataset::DataType::DE_UNKNOWN);
}
}
Status MindDataEager::LoadImageFromDir(const std::string &image_dir, std::vector<std::shared_ptr<Tensor>> *images) {
// Check target directory
dataset::Path image_dir_(image_dir);
if (!image_dir_.Exists() || !image_dir_.IsDirectory()) {
std::string err_msg = "Target directory: " + image_dir + " does not exist or not a directory.";
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::FAILED, err_msg);
}
if (access(image_dir_.toString().c_str(), R_OK) == -1) {
std::string err_msg = "No access to target directory: " + image_dir;
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::FAILED, err_msg);
}
// Start reading images and constructing tensors
auto path_itr = dataset::Path::DirIterator::OpenDirectory(&image_dir_);
while (path_itr->hasNext()) {
dataset::Path file = path_itr->next();
std::shared_ptr<dataset::Tensor> image;
dataset::Tensor::CreateFromFile(file.toString(), &image);
std::shared_ptr<Tensor> ms_image = std::make_shared<Tensor>("image", DataType(kMsUint8), image->shape().AsVector(),
image->GetBuffer(), image->SizeInBytes());
images->push_back(ms_image);
}
// Check if read images or not
if (images->empty()) {
std::string err_msg = "No images found in target directory: " + image_dir;
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::FAILED, err_msg);
}
return Status(StatusCode::SUCCESS);
}
std::shared_ptr<Tensor> MindDataEager::operator()(std::shared_ptr<Tensor> input) {
// Validate ops
if (ops_.empty()) {
MS_LOG(ERROR) << "Input TensorOperation should be provided";
return nullptr;
}
for (int32_t i = 0; i < ops_.size(); i++) {
if (ops_[i] == nullptr) {
MS_LOG(ERROR) << "Input TensorOperation[" << i << "] is invalid or null";
return nullptr;
}
}
// Validate input tensor
if (input == nullptr) {
MS_LOG(ERROR) << "Input Tensor should not be null";
return nullptr;
}
// Start applying transforms in ops
std::shared_ptr<dataset::Tensor> de_input;
dataset::Tensor::CreateFromMemory(dataset::TensorShape(input->Shape()), ToDEType(input->DataType()),
(const uchar *)(input->Data()), &de_input);
for (int32_t i = 0; i < ops_.size(); i++) {
// Build runtime op and run
std::shared_ptr<dataset::Tensor> de_output;
std::shared_ptr<dataset::TensorOp> transform = ops_[i]->Build();
dataset::Status rc = transform->Compute(de_input, &de_output);
// check execution failed
if (rc.IsError()) {
MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
return nullptr;
}
// For next transform
de_input = std::move(de_output);
}
// Convert DETensor to Tensor
if (!de_input->HasData()) {
MS_LOG(ERROR) << "Apply transform failed, output tensor has no data";
return nullptr;
}
std::shared_ptr<Tensor> output =
std::make_shared<Tensor>("transfomed", ToMSType(de_input->type()), de_input->shape().AsVector(),
de_input->GetBuffer(), de_input->SizeInBytes());
return output;
}
} // namespace api
} // namespace mindspore

View File

@ -29,25 +29,42 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
return execute; return execute;
})) }))
.def("__call__", .def("__call__",
[](Execute &self, std::shared_ptr<Tensor> in) { [](Execute &self, const std::shared_ptr<Tensor> &de_tensor) {
std::shared_ptr<Tensor> out = self(in); auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
if (out == nullptr) { Status rc = self(ms_tensor, &ms_tensor);
THROW_IF_ERROR([]() { if (rc.IsError()) {
RETURN_STATUS_UNEXPECTED( THROW_IF_ERROR([&rc]() {
"Failed to execute op in eager mode, please check ERROR log above."); RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString());
}()); }());
} }
return out; std::shared_ptr<dataset::Tensor> de_output_tensor;
dataset::Tensor::CreateFromMemory(dataset::TensorShape(ms_tensor.Shape()),
MSTypeToDEType(static_cast<TypeId>(ms_tensor.DataType())),
(const uchar *)(ms_tensor.Data().get()),
ms_tensor.DataSize(), &de_output_tensor);
return de_output_tensor;
}) })
.def("__call__", [](Execute &self, const std::vector<std::shared_ptr<Tensor>> &input_tensor_list) { .def("__call__", [](Execute &self, const std::vector<std::shared_ptr<Tensor>> &input_tensor_list) {
std::vector<std::shared_ptr<Tensor>> output_tensor_list; std::vector<MSTensor> ms_input_tensor_list;
THROW_IF_ERROR(self(input_tensor_list, &output_tensor_list)); std::vector<MSTensor> ms_output_tensor_list;
if (output_tensor_list.empty()) { for (auto &tensor : input_tensor_list) {
THROW_IF_ERROR([]() { auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above."); ms_input_tensor_list.emplace_back(std::move(ms_tensor));
}());
} }
return output_tensor_list; Status rc = self(ms_input_tensor_list, &ms_output_tensor_list);
if (rc.IsError()) {
THROW_IF_ERROR(
[&rc]() { RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString()); }());
}
std::vector<std::shared_ptr<dataset::Tensor>> de_output_tensor_list;
for (auto &tensor : ms_output_tensor_list) {
std::shared_ptr<dataset::Tensor> de_output_tensor;
dataset::Tensor::CreateFromMemory(
dataset::TensorShape(tensor.Shape()), MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
(const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_output_tensor);
de_output_tensor_list.emplace_back(std::move(de_output_tensor));
}
return de_output_tensor_list;
}); });
})); }));
} // namespace dataset } // namespace dataset
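For completeness, the MSTensor to dataset::Tensor conversion that both this binding and Execute::operator() rely on, factored into a helper: it copies the shape, the mapped type and the raw bytes through Tensor::CreateFromMemory. Include paths are assumptions.

#include <memory>
#include "include/api/types.h"
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/type_id.h"  // assumed to declare MSTypeToDEType

// Convert an MSTensor back into a dataset::Tensor, as done for each output above.
mindspore::Status ToDETensor(const mindspore::MSTensor &ms_tensor,
                             std::shared_ptr<mindspore::dataset::Tensor> *out) {
  return mindspore::dataset::Tensor::CreateFromMemory(
      mindspore::dataset::TensorShape(ms_tensor.Shape()),
      mindspore::dataset::MSTypeToDEType(static_cast<mindspore::TypeId>(ms_tensor.DataType())),
      reinterpret_cast<const unsigned char *>(ms_tensor.Data().get()),
      ms_tensor.DataSize(), out);
}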

View File

@ -84,7 +84,8 @@ PYBIND_REGISTER(SliceOption, 0, ([](const py::module *m) {
} }
if (!c_slice.valid()) { if (!c_slice.valid()) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); THROW_IF_ERROR(
Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
} }
return SliceOption(c_slice); return SliceOption(c_slice);
})) }))

View File

@ -354,7 +354,7 @@ PYBIND_REGISTER(
for (auto handle : py_sub.cast<py::list>()) { for (auto handle : py_sub.cast<py::list>()) {
py::tuple tp = handle.cast<py::tuple>(); py::tuple tp = handle.cast<py::tuple>();
if (tp.is_none() || tp.size() != 2) { if (tp.is_none() || tp.size() != 2) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob).")); THROW_IF_ERROR(Status(StatusCode::kMDUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
} }
std::shared_ptr<TensorOperation> t_op; std::shared_ptr<TensorOperation> t_op;
if (py::isinstance<TensorOperation>(tp[0])) { if (py::isinstance<TensorOperation>(tp[0])) {
@ -366,11 +366,11 @@ PYBIND_REGISTER(
std::make_shared<PyFuncOp>((tp[0]).cast<py::function>())); std::make_shared<PyFuncOp>((tp[0]).cast<py::function>()));
} else { } else {
THROW_IF_ERROR( THROW_IF_ERROR(
Status(StatusCode::kUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc.")); Status(StatusCode::kMDUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
} }
double prob = (tp[1]).cast<py::float_>(); double prob = (tp[1]).cast<py::float_>();
if (prob < 0 || prob > 1) { if (prob < 0 || prob > 1) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1].")); THROW_IF_ERROR(Status(StatusCode::kMDUnexpectedError, "prob needs to be with [0,1]."));
} }
cpp_policy.back().emplace_back(std::make_pair(t_op, prob)); cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
} }

View File

@ -51,12 +51,12 @@ Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param
// Acquire Python GIL // Acquire Python GIL
py::gil_scoped_acquire gil_acquire; py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) { if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
} }
try { try {
f(cb_param); f(cb_param);
} catch (const py::error_already_set &e) { } catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what()); return Status(StatusCode::kMDPyFuncException, e.what());
} }
} }
return Status::OK(); return Status::OK();

View File

@ -5,6 +5,7 @@ set(DATASET_CORE_SRC_FILES
config_manager.cc config_manager.cc
cv_tensor.cc cv_tensor.cc
data_type.cc data_type.cc
de_tensor.cc
global_context.cc global_context.cc
tensor.cc tensor.cc
tensor_helpers.cc tensor_helpers.cc

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/include/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "utils/hashing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#define ASSERT_NULL(ptr) MS_EXCEPTION_IF_NULL(ptr)
#else
#include "mindspore/lite/src/common/log_adapter.h"
#define ASSERT_NULL(ptr) MS_ASSERT((ptr) != nullptr)
#endif
namespace mindspore {
namespace dataset {
DETensor::DETensor(std::shared_ptr<dataset::Tensor> tensor_impl)
: tensor_impl_(tensor_impl),
name_("MindDataTensor"),
type_(static_cast<mindspore::DataType>(DETypeToMSType(tensor_impl_->type()))),
shape_(tensor_impl_->shape().AsVector()) {}
const std::string &DETensor::Name() const { return name_; }
enum mindspore::DataType DETensor::DataType() const {
ASSERT_NULL(tensor_impl_);
return static_cast<mindspore::DataType>(DETypeToMSType(tensor_impl_->type()));
}
size_t DETensor::DataSize() const {
ASSERT_NULL(tensor_impl_);
return tensor_impl_->SizeInBytes();
}
const std::vector<int64_t> &DETensor::Shape() const { return shape_; }
std::shared_ptr<const void> DETensor::Data() const {
return std::shared_ptr<const void>(tensor_impl_->GetBuffer(), [](const void *) {});
}
void *DETensor::MutableData() {
ASSERT_NULL(tensor_impl_);
return tensor_impl_->GetMutableBuffer();
}
bool DETensor::IsDevice() const { return false; }
std::shared_ptr<mindspore::MSTensor::Impl> DETensor::Clone() const { return std::make_shared<DETensor>(tensor_impl_); }
} // namespace dataset
} // namespace mindspore
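A small sketch showing what this adapter enables: a dataset::Tensor can be exposed through the new MSTensor::Impl interface without copying its buffer. The file path and helper name are placeholders.

#include <memory>
#include <string>
#include "include/api/types.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/include/tensor.h"

// Read an image file into a dataset::Tensor and expose it as an MSTensor.
mindspore::MSTensor WrapImage(const std::string &path) {
  std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
  (void)mindspore::dataset::Tensor::CreateFromFile(path, &de_tensor);
  // Data()/Shape()/DataSize() on the returned MSTensor forward to the wrapped dataset::Tensor.
  return mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
}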

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_
#include <string>
#include <vector>
#include <memory>
#include "include/api/types.h"
#include "mindspore/core/ir/api_tensor_impl.h"
#include "minddata/dataset/include/status.h"
#include "minddata/dataset/include/tensor.h"
namespace mindspore {
namespace dataset {
class DETensor : public mindspore::MSTensor::Impl {
public:
DETensor() = default;
~DETensor() override = default;
explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_impl);
const std::string &Name() const override;
enum mindspore::DataType DataType() const override;
size_t DataSize() const override;
const std::vector<int64_t> &Shape() const override;
std::shared_ptr<const void> Data() const override;
void *MutableData() override;
bool IsDevice() const override;
std::shared_ptr<mindspore::MSTensor::Impl> Clone() const override;
private:
std::shared_ptr<dataset::Tensor> tensor_impl_;
std::string name_;
enum mindspore::DataType type_;
std::vector<int64_t> shape_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_

View File

@ -41,23 +41,17 @@
#include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor_helpers.h" #include "minddata/dataset/core/tensor_helpers.h"
#include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "proto/example.pb.h" #include "proto/example.pb.h"
#else
#include "minddata/dataset/include/de_tensor.h"
#endif #endif
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
namespace py = pybind11; namespace py = pybind11;
#endif #endif
namespace mindspore { namespace mindspore {
#ifdef ENABLE_ANDROID
namespace tensor {
class DETensor;
} // namespace tensor
#endif
namespace dataset { namespace dataset {
class Tensor; class Tensor;
template <typename T> template <typename T>
@ -85,7 +79,7 @@ class Tensor {
/// \param other Tensor to be moved /// \param other Tensor to be moved
Tensor(Tensor &&other) noexcept; Tensor(Tensor &&other) noexcept;
/// Move assigment operator /// Move assignment operator
/// \param other Tensor to be moved /// \param other Tensor to be moved
Tensor &operator=(Tensor &&other) noexcept; Tensor &operator=(Tensor &&other) noexcept;
@ -134,7 +128,7 @@ class Tensor {
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// Create a tensor of type DE_STRING from a BytesList. /// Create a tensor of type DE_STRING from a BytesList.
/// \param[in] bytes_list protobuf's Bytelist /// \param[in] bytes_list protobuf's Bytelist
/// \param[in] shape shape of the outout tensor /// \param[in] shape shape of the output tensor
/// \param[out] out created Tensor /// \param[out] out created Tensor
/// \return Status Code /// \return Status Code
static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out); static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out);
@ -292,7 +286,7 @@ class Tensor {
std::string err; std::string err;
err += (data_ == nullptr) ? "data_ is nullptr \t" : ""; err += (data_ == nullptr) ? "data_ is nullptr \t" : "";
err += type_.IsCompatible<T>() ? "data type not compatible\t" : ""; err += type_.IsCompatible<T>() ? "data type not compatible\t" : "";
return Status(StatusCode::kUnexpectedError, err); return Status(StatusCode::kMDUnexpectedError, err);
} }
} }
@ -343,7 +337,7 @@ class Tensor {
void Invalidate(); void Invalidate();
/// Copy input tensor into self at the location index. /// Copy input tensor into self at the location index.
/// Index is a vector of axises which can be incomplete: /// Index is a vector of axes which can be incomplete:
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. /// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
/// \param index /// \param index
/// \param input /// \param input
@ -686,9 +680,7 @@ class Tensor {
unsigned char *data_end_ = nullptr; unsigned char *data_end_ = nullptr;
private: private:
#ifdef ENABLE_ANDROID friend class DETensor;
friend class tensor::DETensor;
#endif
/// Slice numeric tensors. /// Slice numeric tensors.
Status SliceNumeric(TensorPtr *out, const std::vector<std::vector<dsize_t>> &indices, const TensorShape &shape); Status SliceNumeric(TensorPtr *out, const std::vector<std::vector<dsize_t>> &indices, const TensorShape &shape);

View File

@ -73,6 +73,7 @@ if(ENABLE_CACHE)
engine-cache-server engine-cache-server
_c_dataengine _c_dataengine
_c_mindrecord _c_mindrecord
mindspore
mindspore::protobuf mindspore::protobuf
mindspore::grpc++ mindspore::grpc++
mindspore_gvar mindspore_gvar
@ -85,6 +86,7 @@ if(ENABLE_CACHE)
engine-cache-server engine-cache-server
_c_dataengine _c_dataengine
_c_mindrecord _c_mindrecord
mindspore
mindspore::protobuf mindspore::protobuf
mindspore::grpc++ mindspore::grpc++
mindspore_gvar mindspore_gvar
@ -103,6 +105,7 @@ if(ENABLE_CACHE)
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc) add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread) target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread)
target_link_libraries(cache_admin mindspore mindspore_shared_lib)
if(USE_GLOG) if(USE_GLOG)
target_link_libraries(cache_admin mindspore::glog) target_link_libraries(cache_admin mindspore::glog)

View File

@ -22,10 +22,11 @@
#include "minddata/dataset/engine/cache/cache_common.h" #include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset; namespace ds = mindspore::dataset;
int main(int argc, char **argv) { int main(int argc, char **argv) {
ds::Status rc; ms::Status rc;
ds::CacheAdminArgHandler args; ds::CacheAdminArgHandler args;
std::stringstream arg_stream; std::stringstream arg_stream;

View File

@ -89,7 +89,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
ArgValue selected_arg = arg_map_[option]; ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) { if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once."; std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
// Flag that this arg is used now // Flag that this arg is used now
@ -101,7 +101,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
if (command_id != CommandId::kCmdUnknown) { if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) { if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option; std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} else { } else {
command_id_ = command_id; command_id_ = command_id;
} }
@ -113,7 +113,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
*arg_stream >> value_as_string; *arg_stream >> value_as_string;
if (value_as_string.empty()) { if (value_as_string.empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
// Now, attempt to convert the value into it's numeric format for output // Now, attempt to convert the value into it's numeric format for output
@ -121,7 +121,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
*out_arg = std::stoul(value_as_string); *out_arg = std::stoul(value_as_string);
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value: " + value_as_string; std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
return Status::OK(); return Status::OK();
@ -133,7 +133,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
ArgValue selected_arg = arg_map_[option]; ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) { if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once."; std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
// Flag that this arg is used now // Flag that this arg is used now
@ -145,7 +145,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
if (command_id != CommandId::kCmdUnknown) { if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) { if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option; std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} else { } else {
command_id_ = command_id; command_id_ = command_id;
} }
@ -158,12 +158,12 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
*arg_stream >> *out_arg; *arg_stream >> *out_arg;
} else { } else {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
if (out_arg->empty()) { if (out_arg->empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
} }
@ -176,7 +176,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
ArgValue selected_arg = arg_map_[option]; ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) { if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once."; std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
// Flag that this arg is used now // Flag that this arg is used now
@ -188,7 +188,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
if (command_id != CommandId::kCmdUnknown) { if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) { if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option; std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} else { } else {
command_id_ = command_id; command_id_ = command_id;
} }
@ -200,7 +200,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
*arg_stream >> value_as_string; *arg_stream >> value_as_string;
if (value_as_string.empty()) { if (value_as_string.empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
// Now, attempt to convert the value into it's string format for output // Now, attempt to convert the value into it's string format for output
@ -208,7 +208,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
*out_arg = std::stof(value_as_string, nullptr); *out_arg = std::stof(value_as_string, nullptr);
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value: " + value_as_string; std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
return Status::OK(); return Status::OK();
@ -224,7 +224,7 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
if (hostname_ != std::string(kCfgDefaultCacheHost)) { if (hostname_ != std::string(kCfgDefaultCacheHost)) {
std::string err_msg = std::string err_msg =
"Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used."; "Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used.";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
break; break;
} }
@ -304,7 +304,7 @@ Status CacheAdminArgHandler::Validate() {
if (!trailing_args_.empty()) { if (!trailing_args_.empty()) {
std::string err_msg = "Invalid arguments provided: " + trailing_args_; std::string err_msg = "Invalid arguments provided: " + trailing_args_;
err_msg += "\nPlease try `cache_admin --help` for more information"; err_msg += "\nPlease try `cache_admin --help` for more information";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
// The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to // The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to
@ -312,18 +312,18 @@ Status CacheAdminArgHandler::Validate() {
if (command_id_ == CommandId::kCmdUnknown) { if (command_id_ == CommandId::kCmdUnknown) {
std::string err_msg = "No command provided"; std::string err_msg = "No command provided";
err_msg += "\nPlease try `cache_admin --help` for more information"; err_msg += "\nPlease try `cache_admin --help` for more information";
return Status(StatusCode::kSyntaxError, err_msg); return Status(StatusCode::kMDSyntaxError, err_msg);
} }
// Additional checks here // Additional checks here
auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100); auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100);
if (num_workers_ < 1 || num_workers_ > max_num_workers) if (num_workers_ < 1 || num_workers_ > max_num_workers)
return Status(StatusCode::kSyntaxError, return Status(StatusCode::kMDSyntaxError,
"Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + "."); "Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + ".");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3)."); if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kMDSyntaxError, "Log level must be in range (0..3).");
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1"); return Status(StatusCode::kMDSyntaxError, "Memory cap ratio should be positive and no greater than 1");
if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535)."); if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kMDSyntaxError, "Port must be in range (1025..65535).");
return Status::OK(); return Status::OK();
} }
@ -467,9 +467,9 @@ Status CacheAdminArgHandler::StopServer(CommandId command_id) {
Status rc = rq->Wait(); Status rc = rq->Wait();
if (rc.IsError()) { if (rc.IsError()) {
msg.RemoveResourcesOnExit(); msg.RemoveResourcesOnExit();
if (rc.IsNetWorkError()) { if (rc == StatusCode::kMDNetWorkError) {
std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already."; std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already.";
return Status(StatusCode::kNetWorkError, errMsg); return Status(StatusCode::kMDNetWorkError, errMsg);
} }
return rc; return rc;
} }
@ -544,7 +544,7 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
if (WIFEXITED(status)) { if (WIFEXITED(status)) {
auto exit_status = WEXITSTATUS(status); auto exit_status = WEXITSTATUS(status);
if (exit_status) { if (exit_status) {
return Status(StatusCode::kUnexpectedError, msg); return Status(StatusCode::kMDUnexpectedError, msg);
} else { } else {
// Not an error, some info message goes to stdout // Not an error, some info message goes to stdout
std::cout << msg << std::endl; std::cout << msg << std::endl;
@ -75,7 +75,7 @@ Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, vo
do { do {
std::unique_lock<std::mutex> lock(mux_[slot]); std::unique_lock<std::mutex> lock(mux_[slot]);
rc = shm_pool_[slot]->Allocate(sz, p); rc = shm_pool_[slot]->Allocate(sz, p);
if (rc.IsOutofMemory()) { if (rc == StatusCode::kMDOutOfMemory) {
slot = (slot + 1) % shm_pool_.size(); slot = (slot + 1) % shm_pool_.size();
} }
} while (rc.IsError() && slot != begin_slot); } while (rc.IsError() && slot != begin_slot);
@ -137,7 +137,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
Status CacheClient::AsyncWriteRow(const TensorRow &row) { Status CacheClient::AsyncWriteRow(const TensorRow &row) {
if (async_buffer_stream_ == nullptr) { if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet); return Status(StatusCode::kMDNotImplementedYet);
} }
RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row)); RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
return Status::OK(); return Status::OK();
@ -145,7 +145,7 @@ Status CacheClient::AsyncWriteRow(const TensorRow &row) {
Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) { Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
if (async_buffer_stream_ == nullptr) { if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet); return Status(StatusCode::kMDNotImplementedYet);
} else { } else {
Status rc; Status rc;
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>(); std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
@ -155,7 +155,7 @@ Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
TensorRow row; TensorRow row;
RETURN_IF_NOT_OK(in->PopRow(&row)); RETURN_IF_NOT_OK(in->PopRow(&row));
rc = AsyncWriteRow(row); rc = AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) { if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
tensor_table->push_back(row); tensor_table->push_back(row);
} else if (rc.IsError()) { } else if (rc.IsError()) {
return rc; return rc;
@ -165,7 +165,7 @@ Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
// If not all of them can be sent async, return what's left back to the caller. // If not all of them can be sent async, return what's left back to the caller.
if (!tensor_table->empty()) { if (!tensor_table->empty()) {
in->set_tensor_table(std::move(tensor_table)); in->set_tensor_table(std::move(tensor_table));
return Status(StatusCode::kNotImplementedYet); return Status(StatusCode::kMDNotImplementedYet);
} }
} }
return Status::OK(); return Status::OK();
@ -225,7 +225,8 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
auto cache_state = static_cast<CacheServiceState>(out); auto cache_state = static_cast<CacheServiceState>(out);
if (cache_state == CacheServiceState::kFetchPhase || if (cache_state == CacheServiceState::kFetchPhase ||
(cache_state == CacheServiceState::kBuildPhase && cookie_.empty())) { (cache_state == CacheServiceState::kBuildPhase && cookie_.empty())) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
"Not an error and we should bypass the build phase");
} }
} else { } else {
cinfo_.set_crc(tree_crc); // It's really a new cache we're creating so save our crc in the client cinfo_.set_crc(tree_crc); // It's really a new cache we're creating so save our crc in the client
@ -243,10 +244,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag); auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(PushRequest(rq)); RETURN_IF_NOT_OK(PushRequest(rq));
Status rc = rq->Wait(); Status rc = rq->Wait();
bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey); bool success = (rc.IsOk() || rc.StatusCode() == StatusCode::kMDDuplicateKey);
// If we get kDuplicateKey, it just means we aren't the first one to create the cache, // If we get kDuplicateKey, it just means we aren't the first one to create the cache,
// and we will continue to parse the result. // and we will continue to parse the result.
if (rc.get_code() == StatusCode::kDuplicateKey) { if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
RETURN_IF_NOT_OK(rq->PostReply()); RETURN_IF_NOT_OK(rq->PostReply());
} }
if (success) { if (success) {
@ -443,7 +444,7 @@ Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
} }
// If the size is too big, tell the user to send it directly. // If the size is too big, tell the user to send it directly.
if (sz > kAsyncBufferSize) { if (sz > kAsyncBufferSize) {
return Status(StatusCode::kNotImplementedYet); return Status(StatusCode::kMDNotImplementedYet);
} }
std::unique_lock<std::mutex> lock(mux_); std::unique_lock<std::mutex> lock(mux_);
// Check error from the server side while we have the lock; // Check error from the server side while we have the lock;
@ -66,7 +66,7 @@ enum class CacheServiceState : int8_t {
/// \param rc[in] Status object /// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object /// \param reply[in/out] pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) { inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
reply->set_rc(static_cast<int32_t>(rc.get_code())); reply->set_rc(static_cast<int32_t>(rc.StatusCode()));
reply->set_msg(rc.ToString()); reply->set_msg(rc.ToString());
} }
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number /// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
@ -98,7 +98,7 @@ Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffer
(*out_fbb) = std::move(fbb); (*out_fbb) = std::move(fbb);
return Status::OK(); return Status::OK();
} catch (const std::bad_alloc &e) { } catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
} }
} }
@ -95,7 +95,7 @@ Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
std::unique_lock<std::mutex> lck(mux_); std::unique_lock<std::mutex> lck(mux_);
auto r = req_.emplace(seqNo, std::move(tag)); auto r = req_.emplace(seqNo, std::move(tag));
if (!r.second) { if (!r.second) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__);
} }
} }
// Last step is to tag the request. // Last step is to tag the request.
@ -124,7 +124,7 @@ Status CacheClientGreeter::WorkerEntry() {
} else { } else {
err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
} }
Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg); Status remote_rc = Status(StatusCode::kMDNetWorkError, __LINE__, __FILE__, err_msg);
Status2CacheReply(remote_rc, &rq->base_rq_->reply_); Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
} }
// Notify the waiting thread. // Notify the waiting thread.
@ -25,7 +25,7 @@ Status PortToFtok(int port, SharedMemory::shm_key_t *out) {
shmkey = ftok(unix_path.data(), 'a'); shmkey = ftok(unix_path.data(), 'a');
if (shmkey == (key_t)-1) { if (shmkey == (key_t)-1) {
std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno); std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno);
return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg); return Status(errno == ENOENT ? StatusCode::kMDFileNotExist : StatusCode::kMDUnexpectedError, errMsg);
} }
*out = shmkey; *out = shmkey;
return Status::OK(); return Status::OK();
@ -56,7 +56,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
CacheMsgBuf msg{ CacheMsgBuf msg{
1, 1,
}; };
msg.body.status.err_code = static_cast<int32_t>(rc.get_code()); msg.body.status.err_code = static_cast<int32_t>(rc.StatusCode());
auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size()); auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size());
CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err)); CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err));
msg.body.status.err_msg[rc.ToString().size()] = '\0'; msg.body.status.err_msg[rc.ToString().size()] = '\0';
@ -25,16 +25,17 @@
#include <chrono> #include <chrono>
#include "minddata/dataset/engine/cache/cache_common.h" #include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h" #include "minddata/dataset/engine/cache/cache_ipc.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset; namespace ds = mindspore::dataset;
/// Start the server /// Start the server
/// \param argv /// \param argv
/// \return Status object /// \return Status object
ds::Status StartServer(int argc, char **argv) { ms::Status StartServer(int argc, char **argv) {
ds::Status rc; ms::Status rc;
ds::CacheServer::Builder builder; ds::CacheServer::Builder builder;
if (argc != 8) { if (argc != 8) {
return ds::Status(ds::StatusCode::kSyntaxError); return ms::Status(ms::StatusCode::kMDSyntaxError);
} }
int32_t port = strtol(argv[3], nullptr, 10); int32_t port = strtol(argv[3], nullptr, 10);
@ -53,7 +54,7 @@ ds::Status StartServer(int argc, char **argv) {
// is called. This is a standard procedure for daemonize a process on unix. // is called. This is a standard procedure for daemonize a process on unix.
if (chdir("/") == -1) { if (chdir("/") == -1) {
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno); std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} }
// A message queue for communication between parent and child (if we fork). // A message queue for communication between parent and child (if we fork).
@ -80,13 +81,13 @@ ds::Status StartServer(int argc, char **argv) {
// failed to fork // failed to fork
if (pid < 0) { if (pid < 0) {
std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno); std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else if (pid > 0) { } else if (pid > 0) {
// Parent and will be responsible for remove the queue on exit. // Parent and will be responsible for remove the queue on exit.
msg.RemoveResourcesOnExit(); msg.RemoveResourcesOnExit();
// Sleep one second and we attach to the msg que // Sleep one second and we attach to the msg que
std::this_thread::sleep_for(std::chrono::seconds(1)); std::this_thread::sleep_for(std::chrono::seconds(1));
ds::Status child_rc; ms::Status child_rc;
rc = msg.ReceiveStatus(&child_rc); rc = msg.ReceiveStatus(&child_rc);
if (rc.IsError()) { if (rc.IsError()) {
return rc; return rc;
@ -101,7 +102,7 @@ ds::Status StartServer(int argc, char **argv) {
"logs (under " "logs (under "
<< ds::DefaultLogDir() << ") for any issues that may happen after startup\n"; << ds::DefaultLogDir() << ") for any issues that may happen after startup\n";
signal(SIGCHLD, SIG_IGN); // ignore sig child signal. signal(SIGCHLD, SIG_IGN); // ignore sig child signal.
return ds::Status::OK(); return ms::Status::OK();
} else { } else {
// Child process will continue from here if daemonize and parent has already exited. // Child process will continue from here if daemonize and parent has already exited.
// If we are running in the foreground, none of the code in block below will be run. // If we are running in the foreground, none of the code in block below will be run.
@ -110,7 +111,7 @@ ds::Status StartServer(int argc, char **argv) {
sid = setsid(); sid = setsid();
if (sid < 0) { if (sid < 0) {
std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno); std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} }
close(0); close(0);
close(1); close(1);
@ -137,10 +138,10 @@ ds::Status StartServer(int argc, char **argv) {
int main(int argc, char **argv) { int main(int argc, char **argv) {
// This executable is not to be called directly, and should be invoked by cache_admin executable. // This executable is not to be called directly, and should be invoked by cache_admin executable.
ds::Status rc = StartServer(argc, argv); ms::Status rc = StartServer(argc, argv);
// Check result // Check result
if (rc.IsError()) { if (rc.IsError()) {
auto errCode = rc.get_code(); auto errCode = rc.StatusCode();
auto errMsg = rc.ToString(); auto errMsg = rc.ToString();
std::cerr << errMsg << std::endl; std::cerr << errMsg << std::endl;
return static_cast<int>(errCode); return static_cast<int>(errCode);
@ -136,7 +136,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
if (rc.IsOk()) { if (rc.IsOk()) {
*p = ptr; *p = ptr;
break; break;
} else if (rc.IsOutofMemory()) { } else if (rc == StatusCode::kMDOutOfMemory) {
inx = (inx + 1) % num_slots; inx = (inx + 1) % num_slots;
} else { } else {
return rc; return rc;
@ -162,7 +162,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
if (rc.IsOk()) { if (rc.IsOk()) {
*p = ptr; *p = ptr;
break; break;
} else if (rc.IsOutofMemory()) { } else if (rc == StatusCode::kMDOutOfMemory) {
// Make the next arena and continue. // Make the next arena and continue.
slot = (slot + 1) % num_segments; slot = (slot + 1) % num_segments;
} else { } else {
@ -172,7 +172,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
} }
// Handle the case we have done one round robin search. // Handle the case we have done one round robin search.
if (ptr == nullptr) { if (ptr == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
} }
return rc; return rc;
} }
@ -108,7 +108,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
bl.ptr = nullptr; bl.ptr = nullptr;
return rc; return rc;
} }
} else if (rc.IsOutofMemory()) { } else if (rc == StatusCode::kMDOutOfMemory) {
// If no memory, write to disk. // If no memory, write to disk.
if (sm_ != nullptr) { if (sm_ != nullptr) {
MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes.";
@ -116,7 +116,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
} else { } else {
// If asked to spill to disk instead but there is no storage set up, simply return no memory // If asked to spill to disk instead but there is no storage set up, simply return no memory
// instead. // instead.
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "No enough storage for cache server to cache data"); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "No enough storage for cache server to cache data");
} }
} else { } else {
return rc; return rc;
@ -125,7 +125,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
try { try {
rc = tree_->DoInsert(key, bl); rc = tree_->DoInsert(key, bl);
} catch (const std::bad_alloc &e) { } catch (const std::bad_alloc &e) {
rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); rc = Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
} }
// Duplicate key is treated as error and we will also free the memory. // Duplicate key is treated as error and we will also free the memory.
if (rc.IsError() && bl.ptr != nullptr) { if (rc.IsError() && bl.ptr != nullptr) {
@ -223,7 +223,7 @@ Status CreateCacheRequest::Prepare() {
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK(); return Status::OK();
} catch (const std::bad_alloc &e) { } catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
} }
} }
@ -277,7 +277,7 @@ Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK(); return Status::OK();
} catch (const std::bad_alloc &e) { } catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
} }
} }
@ -169,7 +169,7 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
max_avail -= mem_consumed; max_avail -= mem_consumed;
if (max_avail <= 0) { if (max_avail <= 0) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} }
++it; ++it;
} }
@ -179,12 +179,12 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
if (max_avail < avail_mem) { if (max_avail < avail_mem) {
int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit. int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit.
if (req_mem > max_avail) { if (req_mem > max_avail) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} else if (req_mem == 0) { } else if (req_mem == 0) {
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
// 85% of our limit, fail this request. // 85% of our limit, fail this request.
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) { if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} }
} }
} }
@ -249,7 +249,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
client_id = cs->num_clients_.fetch_add(1); client_id = cs->num_clients_.fetch_add(1);
all_caches_.emplace(connection_id, std::move(cs)); all_caches_.emplace(connection_id, std::move(cs));
} catch (const std::bad_alloc &e) { } catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory); return Status(StatusCode::kMDOutOfMemory);
} }
} }
@ -276,7 +276,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
// treat it as OK. // treat it as OK.
return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK(); return duplicate ? Status(StatusCode::kMDDuplicateKey) : Status::OK();
} }
Status CacheServer::DestroyCache(CacheRequest *rq) { Status CacheServer::DestroyCache(CacheRequest *rq) {
@ -306,7 +306,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
auto sz = rq->buf_data_size(); auto sz = rq->buf_data_size();
std::vector<const void *> buffers; std::vector<const void *> buffers;
@ -326,7 +326,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id)); RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
reply->set_result(std::to_string(id)); reply->set_result(std::to_string(id));
} else { } else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
} }
} }
return Status::OK(); return Status::OK();
@ -353,7 +353,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
Status rc; Status rc;
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
// Only if the cookie matches, we can accept insert into this cache that has a build phase // Only if the cookie matches, we can accept insert into this cache that has a build phase
if (!cs->HasBuildPhase() || cookie == cs->cookie()) { if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
@ -365,11 +365,11 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
} else { } else {
auto state = cs->GetState(); auto state = cs->GetState();
if (state != CacheServiceState::kFetchPhase) { if (state != CacheServiceState::kFetchPhase) {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Cache service is not in fetch phase. The current phase is " + "Cache service is not in fetch phase. The current phase is " +
std::to_string(static_cast<int8_t>(state)) + ". Client id: " + std::to_string(client_id)); std::to_string(static_cast<int8_t>(state)) + ". Client id: " + std::to_string(client_id));
} else { } else {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Cookie mismatch. Client id: " + std::to_string(client_id)); "Cookie mismatch. Client id: " + std::to_string(client_id));
} }
} }
@ -413,7 +413,7 @@ Status CacheServer::InternalFetchRow(CacheRequest *rq) {
Status rc; Status rc;
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} }
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data())); rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data()));
// This is an internal request and is not tied to rpc. But need to post because there // This is an internal request and is not tied to rpc. But need to post because there
@ -494,7 +494,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id"); CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
auto &row_id_buf = rq->buf_data(0); auto &row_id_buf = rq->buf_data(0);
@ -551,7 +551,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
mem.resize(mem_sz); mem.resize(mem_sz);
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error"); CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
} catch (const std::bad_alloc &e) { } catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory); return Status(StatusCode::kMDOutOfMemory);
} }
WritableSlice dest(mem.data(), mem_sz); WritableSlice dest(mem.data(), mem_sz);
RETURN_IF_NOT_OK(BatchFetch(fbb, &dest)); RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
@ -568,7 +568,7 @@ Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
CacheService::ServiceStat svc_stat; CacheService::ServiceStat svc_stat;
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
@ -595,7 +595,7 @@ Status CacheServer::CacheSchema(CacheRequest *rq) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information"); CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
auto &create_schema_buf = rq->buf_data(0); auto &create_schema_buf = rq->buf_data(0);
@ -611,7 +611,7 @@ Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
// We are going to use std::string to allocate and hold the result which will be eventually // We are going to use std::string to allocate and hold the result which will be eventually
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
@ -630,7 +630,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
// First piece of data is the cookie // First piece of data is the cookie
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie"); CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
@ -639,7 +639,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
if (cookie == cs->cookie()) { if (cookie == cs->cookie()) {
RETURN_IF_NOT_OK(cs->BuildPhaseDone()); RETURN_IF_NOT_OK(cs->BuildPhaseDone());
} else { } else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
} }
} }
return Status::OK(); return Status::OK();
@ -652,7 +652,7 @@ Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
std::vector<row_id_type> gap; std::vector<row_id_type> gap;
RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap)); RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
@ -680,7 +680,7 @@ Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
// First piece of data is the on/off flag // First piece of data is the on/off flag
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag"); CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
@ -747,7 +747,7 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
auto client_id = rq->client_id(); auto client_id = rq->client_id();
MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects"; MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";
@ -836,7 +836,7 @@ Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *inter
default: default:
std::string errMsg("Internal error, request type is not row request: "); std::string errMsg("Internal error, request type is not row request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} }
return Status::OK(); return Status::OK();
} }
@ -860,7 +860,7 @@ Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
default: default:
std::string errMsg("Internal error, request type is not session request: "); std::string errMsg("Internal error, request type is not session request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} }
return Status::OK(); return Status::OK();
} }
@ -931,7 +931,7 @@ Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
default: default:
std::string errMsg("Internal error, request type is not admin request: "); std::string errMsg("Internal error, request type is not admin request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} }
return Status::OK(); return Status::OK();
} }
@ -949,7 +949,7 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
} else { } else {
std::string errMsg("Unknown request type : "); std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} }
// Notify it is done, and move on to the next request. // Notify it is done, and move on to the next request.
@ -1045,7 +1045,7 @@ Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
RETURN_UNEXPECTED_IF_NULL(q); RETURN_UNEXPECTED_IF_NULL(q);
auto *p = new (std::nothrow) CacheServerRequest(); auto *p = new (std::nothrow) CacheServerRequest();
if (p == nullptr) { if (p == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
} }
*q = p; *q = p;
return Status::OK(); return Status::OK();
@ -1091,7 +1091,7 @@ Status CacheServer::DestroySession(CacheRequest *rq) {
} else { } else {
std::string errMsg = std::string errMsg =
"Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + "."; "Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + ".";
return Status(StatusCode::kFileNotExist, errMsg); return Status(StatusCode::kMDFileNotExist, errMsg);
} }
} }
} }
@ -1148,7 +1148,7 @@ Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id); CacheService *cs = GetService(connection_id);
if (cs == nullptr) { if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else { } else {
auto state = cs->GetState(); auto state = cs->GetState();
reply->set_result(std::to_string(static_cast<int8_t>(state))); reply->set_result(std::to_string(static_cast<int8_t>(state)));
@ -1247,7 +1247,7 @@ Status CacheServer::Builder::IpcResourceCleanup() {
std::string errMsg = "Cache server is already up and running"; std::string errMsg = "Cache server is already up and running";
// We return a duplicate error. The main() will intercept // We return a duplicate error. The main() will intercept
// and output a proper message // and output a proper message
return Status(StatusCode::kDuplicateKey, errMsg); return Status(StatusCode::kMDDuplicateKey, errMsg);
} }
return Status::OK(); return Status::OK();
} }
@ -419,7 +419,7 @@ class CacheServer : public Service {
Status GetRc() { Status GetRc() {
Status rc; Status rc;
for (auto &cache_rc : rc_lists_) { for (auto &cache_rc : rc_lists_) {
if (cache_rc.IsError() && !cache_rc.IsInterrupted() && rc.IsOk()) { if (cache_rc.IsError() && cache_rc != StatusCode::kMDInterrupted && rc.IsOk()) {
rc = cache_rc; rc = cache_rc;
} }
} }
@ -42,7 +42,7 @@ Status CacheService::DoServiceStart() {
// Return an error if we use more than recommended memory. // Return an error if we use more than recommended memory.
std::string errMsg = "Requesting cache size " + std::to_string(cache_mem_sz_) + std::string errMsg = "Requesting cache size " + std::to_string(cache_mem_sz_) +
" while available system memory " + std::to_string(avail_mem); " while available system memory " + std::to_string(avail_mem);
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, errMsg); return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, errMsg);
} }
memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem; memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem;
} }
@ -79,7 +79,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
if (st_ == CacheServiceState::kNoLocking) { if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just // We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on. // return out of memory from now on.
return Status(StatusCode::kOutOfMemory); return Status(StatusCode::kMDOutOfMemory);
} }
try { try {
// The first buffer is a flatbuffer which describes the rest of the buffers follow // The first buffer is a flatbuffer which describes the rest of the buffers follow
@ -119,16 +119,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
} }
// Now we cache the buffer. // Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, all_data); Status rc = cp_->Insert(*row_id_generated, all_data);
if (rc == Status(StatusCode::kDuplicateKey)) { if (rc == Status(StatusCode::kMDDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key."; MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else { } else {
if (HasBuildPhase()) { if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state // For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can // so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache. // do to resume other than to drop the cache.
if (rc.IsNoSpace()) { if (rc == StatusCode::kMDNoSpace) {
st_ = CacheServiceState::kNoSpace; st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) { } else if (rc == StatusCode::kMDOutOfMemory) {
st_ = CacheServiceState::kOutOfMemory; st_ = CacheServiceState::kOutOfMemory;
} }
} }
@ -152,7 +152,7 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
if (st_ == CacheServiceState::kNoLocking) { if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just // We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on. // return out of memory from now on.
return Status(StatusCode::kOutOfMemory); return Status(StatusCode::kMDOutOfMemory);
} }
try { try {
// If we don't need to generate id, we need to find it from the buffer. // If we don't need to generate id, we need to find it from the buffer.
@ -172,16 +172,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
} }
// Now we cache the buffer. // Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, {src}); Status rc = cp_->Insert(*row_id_generated, {src});
if (rc == Status(StatusCode::kDuplicateKey)) { if (rc == Status(StatusCode::kMDDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key."; MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else { } else {
if (HasBuildPhase()) { if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state // For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can // so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache. // do to resume other than to drop the cache.
if (rc.IsNoSpace()) { if (rc == StatusCode::kMDNoSpace) {
st_ = CacheServiceState::kNoSpace; st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) { } else if (rc == StatusCode::kMDOutOfMemory) {
st_ = CacheServiceState::kOutOfMemory; st_ = CacheServiceState::kOutOfMemory;
} }
} }
@ -307,7 +307,7 @@ Status CacheService::FetchSchema(std::string *out) const {
if (!mem.empty()) { if (!mem.empty()) {
*out = std::move(mem); *out = std::move(mem);
} else { } else {
return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); return Status(StatusCode::kMDFileNotExist, __LINE__, __FILE__, "No schema has been cached");
} }
return Status::OK(); return Status::OK();
} }
@ -36,7 +36,7 @@ Status CachePerfMsg::Receive(int32_t qID) {
auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR); auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR);
if (err == -1) { if (err == -1) {
if (errno == EIDRM) { if (errno == EIDRM) {
return Status(StatusCode::kInterrupted); return Status(StatusCode::kMDInterrupted);
} else { } else {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno); std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg); RETURN_STATUS_UNEXPECTED(errMsg);
@ -33,7 +33,7 @@ int main(int argc, char **argv) {
if (rc.IsError()) { if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl; std::cerr << rc.ToString() << std::endl;
} }
return static_cast<int>(rc.get_code()); return static_cast<int>(rc.StatusCode());
} }
return 0; return 0;
} }
@ -100,5 +100,7 @@ class CachePerfRun {
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
// todo: waiting for the master of the codes to refactor
#define get_code StatusCode
#define kDuplicateKey kMDDuplicateKey
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_
@ -33,12 +33,12 @@ int main(int argc, char **argv) {
// If we hit any error, send the rc back to the parent. // If we hit any error, send the rc back to the parent.
if (rc.IsError()) { if (rc.IsError()) {
ds::ErrorMsg proto; ds::ErrorMsg proto;
proto.set_rc(static_cast<int32_t>(rc.get_code())); proto.set_rc(static_cast<int32_t>(rc.StatusCode()));
proto.set_msg(rc.ToString()); proto.set_msg(rc.ToString());
ds::CachePerfMsg msg; ds::CachePerfMsg msg;
(void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto); (void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto);
} }
return static_cast<int>(rc.get_code()); return static_cast<int>(rc.StatusCode());
} }
return 0; return 0;
} }
@ -115,5 +115,9 @@ class CachePipelineRun {
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
// todo: waiting for the master of the codes to refactor
#define get_code StatusCode
#define kDuplicateKey kMDDuplicateKey
#define IsOutofMemory() StatusCode() == StatusCode::kMDOutOfMemory
#define IsNoSpace() StatusCode() == StatusCode::kMDNoSpace
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_
@ -104,7 +104,7 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const
if (r_sz != sz) { if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno; errno_t err = (r_sz == 0) ? EOF : errno;
if (errno == ENOSPC) { if (errno == ENOSPC) {
return Status(StatusCode::kNoSpace, __LINE__, __FILE__); return Status(StatusCode::kMDNoSpace, __LINE__, __FILE__);
} else { } else {
RETURN_STATUS_UNEXPECTED(strerror(err)); RETURN_STATUS_UNEXPECTED(strerror(err));
} }
@ -157,7 +157,7 @@ Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer
Status rc; Status rc;
auto sc = new (std::nothrow) StorageContainer(path); auto sc = new (std::nothrow) StorageContainer(path);
if (sc == nullptr) { if (sc == nullptr) {
return Status(StatusCode::kOutOfMemory); return Status(StatusCode::kMDOutOfMemory);
} }
rc = sc->Create(); rc = sc->Create();
if (rc.IsOk()) { if (rc.IsOk()) {
@ -96,9 +96,9 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu
cont = containers_.at(num_containers - 1); cont = containers_.at(num_containers - 1);
off64_t offset; off64_t offset;
Status rc = cont->Insert(buf, &offset); Status rc = cont->Insert(buf, &offset);
if (rc.get_code() == StatusCode::kBuddySpaceFull) { if (rc.StatusCode() == StatusCode::kMDBuddySpaceFull) {
create_new_container = true; create_new_container = true;
// Remember how many containers we saw. In the next iteration we will do a comparision to see // Remember how many containers we saw. In the next iteration we will do a comparison to see
// if someone has already created it. // if someone has already created it.
last_num_container = num_containers; last_num_container = num_containers;
} else if (rc.IsOk()) { } else if (rc.IsOk()) {
@ -140,7 +140,7 @@ Status ColDescriptor::MaterializeTensorShape(int32_t num_elements, TensorShape *
// If we already had an unknown dimension, then we cannot have a second unknown dimension. // If we already had an unknown dimension, then we cannot have a second unknown dimension.
// We only support the compute of a single unknown dim. // We only support the compute of a single unknown dim.
if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) { if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Requested shape has more than one unknown dimension!"); "Requested shape has more than one unknown dimension!");
} }
@ -312,12 +312,12 @@ Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::strin
} }
// data type is mandatory field // data type is mandatory field
if (type_str.empty()) if (type_str.empty())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"json schema file for column " + col_name + " has invalid or missing column type."); "json schema file for column " + col_name + " has invalid or missing column type.");
// rank number is mandatory field // rank number is mandatory field
if (rank_value <= -1) if (rank_value <= -1)
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"json schema file for column " + col_name + " must define a positive rank value."); "json schema file for column " + col_name + " must define a positive rank value.");
// Create the column descriptor for this column from the data we pulled from the json file // Create the column descriptor for this column from the data we pulled from the json file
@ -425,7 +425,7 @@ Status DataSchema::AddColumn(const ColDescriptor &cd) {
Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) { Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
// Check if columns node exists. It is required for building schema from file. // Check if columns node exists. It is required for building schema from file.
if (js.find("columns") == js.end()) if (js.find("columns") == js.end())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"\"columns\" node is required in the schema json file."); "\"columns\" node is required in the schema json file.");
return Status::OK(); return Status::OK();
} }
@ -434,12 +434,12 @@ Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
// name to column index number. // name to column index number.
Status DataSchema::GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map) { Status DataSchema::GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map) {
if (out_column_name_map == nullptr) { if (out_column_name_map == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map."); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map.");
} }
for (int32_t i = 0; i < col_descs_.size(); ++i) { for (int32_t i = 0; i < col_descs_.size(); ++i) {
if (col_descs_[i].name().empty()) { if (col_descs_[i].name().empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Constructing column name map from schema, but found empty column name."); "Constructing column name map from schema, but found empty column name.");
} }
(*out_column_name_map)[col_descs_[i].name()] = i; (*out_column_name_map)[col_descs_[i].name()] = i;
@ -290,7 +290,7 @@ Status ChildIterator::Drain() {
RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
} }
if (curr_buffer_->eof()) { if (curr_buffer_->eof()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain."); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
} }
return Status::OK(); return Status::OK();
} }
@ -122,7 +122,8 @@ Status BarrierOp::prepare(TensorQTable *const table) {
clean_up_ = false; clean_up_ = false;
buffer_id_ = 0; buffer_id_ = 0;
if (table == nullptr) { if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"BarrierOp prepare phase requires a tensor table.");
} }
// fill initial row // fill initial row
TensorRow new_row = {}; TensorRow new_row = {};
@ -150,7 +151,7 @@ Status BarrierOp::prepare(TensorQTable *const table) {
// fillBuffer always expects a new table to fill // fillBuffer always expects a new table to fill
Status BarrierOp::fillBuffer(TensorQTable *const table) { Status BarrierOp::fillBuffer(TensorQTable *const table) {
if (table == nullptr) { if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
} }
TensorRow new_row = {}; TensorRow new_row = {};
while (table->size() < static_cast<size_t>(rows_per_buffer_)) { while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
@ -172,7 +173,7 @@ Status BarrierOp::blockCond() {
{ {
py::gil_scoped_acquire gil_acquire; py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) { if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
} }
// we have condition name, however the flexibility is in python today // we have condition name, however the flexibility is in python today
try { try {
@ -180,11 +181,11 @@ Status BarrierOp::blockCond() {
py::object ret_py_obj = condition_function_(); py::object ret_py_obj = condition_function_();
// Process the return value // Process the return value
if (!py::isinstance<py::bool_>(ret_py_obj)) { if (!py::isinstance<py::bool_>(ret_py_obj)) {
return Status(StatusCode::kPyFuncException, return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, condition wait function should return true/false."); "Invalid parameter, condition wait function should return true/false.");
} }
} catch (const py::error_already_set &e) { } catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what()); return Status(StatusCode::kMDPyFuncException, e.what());
} }
} }
return Status::OK(); return Status::OK();
@@ -61,7 +61,7 @@ Status BatchOp::Builder::SanityCheck() {
   err += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
                                        std::to_string(builder_num_workers_) + ".\n"
                                    : "";
-  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
+  return err.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
 }
 #ifdef ENABLE_PYTHON
@@ -261,7 +261,7 @@ Status BatchOp::MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatc
 Status BatchOp::LaunchThreadsAndInitOp() {
   if (tree_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
   }
   RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(
@@ -338,23 +338,23 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
     // Acquire Python GIL
     py::gil_scoped_acquire gil_acquire;
     if (Py_IsInitialized() == 0) {
-      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
+      return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized.");
     }
     try {
       py::object size = batch_size_func_(info);
       *batch_size = size.cast<int32_t>();
       if (*batch_size <= 0) {
-        return Status(StatusCode::kPyFuncException,
+        return Status(StatusCode::kMDPyFuncException,
                       "Invalid parameter, batch size function should return an integer greater than 0.");
       }
     } catch (const py::error_already_set &e) {
-      return Status(StatusCode::kPyFuncException, e.what());
+      return Status(StatusCode::kMDPyFuncException, e.what());
     } catch (const py::cast_error &e) {
-      return Status(StatusCode::kPyFuncException,
+      return Status(StatusCode::kMDPyFuncException,
                     "Invalid parameter, batch size function should return an integer greater than 0.");
     }
   }
-  return Status(StatusCode::kOK, "Batch size func call succeed.");
+  return Status(StatusCode::kSuccess, "Batch size func call succeed.");
 }
 Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
@@ -362,7 +362,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
     // Acquire Python GIL
     py::gil_scoped_acquire gil_acquire;
     if (Py_IsInitialized() == 0) {
-      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
+      return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized.");
     }
     try {
       // Prepare batch map call back parameters
@@ -407,9 +407,9 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
         output->push_back(std::move(output_batch));
       }
     } catch (const py::error_already_set &e) {
-      return Status(StatusCode::kPyFuncException, e.what());
+      return Status(StatusCode::kMDPyFuncException, e.what());
     } catch (const py::cast_error &e) {
-      return Status(StatusCode::kPyFuncException,
+      return Status(StatusCode::kMDPyFuncException,
                     "Invalid parameter, batch map function should return a tuple of list of numpy array.");
     }
   }
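The InvokeBatchSizeFunc and InvokeBatchMapFunc hunks above keep the same control flow (check the interpreter, acquire the GIL, call the Python callable, map pybind11 exceptions to a status) and only rename the codes. A small stand-alone sketch of that pattern is below; it uses pybind11's embedded interpreter purely so the example can run on its own, and the Status struct is a simplified stand-in rather than the MindSpore class.

#include <pybind11/embed.h>
#include <cstdint>
#include <iostream>
#include <string>

namespace py = pybind11;

enum class StatusCode { kSuccess, kMDPyFuncException, kMDPythonInterpreterFailure };

struct Status {
  StatusCode code = StatusCode::kSuccess;
  std::string msg = "Batch size func call succeed.";
};

// Same shape as BatchOp::InvokeBatchSizeFunc above, with the renamed codes.
Status InvokeBatchSizeFunc(const py::object &batch_size_func, int32_t *batch_size) {
  if (Py_IsInitialized() == 0) {
    return {StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized."};
  }
  try {
    py::object size = batch_size_func();
    *batch_size = size.cast<int32_t>();
    if (*batch_size <= 0) {
      return {StatusCode::kMDPyFuncException,
              "Invalid parameter, batch size function should return an integer greater than 0."};
    }
  } catch (const py::error_already_set &e) {
    return {StatusCode::kMDPyFuncException, e.what()};
  } catch (const py::cast_error &) {
    return {StatusCode::kMDPyFuncException,
            "Invalid parameter, batch size function should return an integer greater than 0."};
  }
  return {};
}

int main() {
  py::scoped_interpreter guard{};  // only for the demo; the real operator acquires the GIL instead
  int32_t batch_size = 0;
  Status s = InvokeBatchSizeFunc(py::eval("lambda: 32"), &batch_size);
  std::cout << batch_size << " : " << s.msg << std::endl;
  return 0;
}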

View File

@@ -191,7 +191,7 @@ Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t ba
     if (bucket_index + 1 >= bucket_boundaries_.size()) {
       std::string error_message =
         "Invalid data, requested to pad to bucket boundary, element falls in last bucket.";
-      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message);
+      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, error_message);
     }
     pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1;

View File

@@ -42,7 +42,7 @@ BuildSentencePieceVocabOp::BuildSentencePieceVocabOp(std::shared_ptr<SentencePie
 Status BuildSentencePieceVocabOp::operator()() {
   if (tree_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
   }
   RETURN_IF_NOT_OK(sentence_queue_->Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
@@ -84,10 +84,10 @@ Status BuildSentencePieceVocabOp::SentenceThread() {
   sentencepiece::util::Status s_status =
     sentencepiece::SentencePieceTrainer::Train(BuildParams(), sentence_iter.get(), &model_proto);
   if (!s_status.ok()) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, s_status.message());
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, s_status.message());
   } else {
     if (vocab_ == nullptr) {
-      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                     "Invalid parameter, sentencepiece vocab not set.");
     }
     vocab_->set_model_proto(model_proto);
@@ -145,7 +145,7 @@ void BuildSentencePieceVocabOp::Next(std::string *sentence) {
   if (new_row[col_id_]->type().IsNumeric() || new_row[col_id_]->Rank() > 1) {
     ret_status_ =
-      Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+      Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
              "Invalid data, build_sentence_piece_vocab only works on string data with rank equal to 1, got type: " +
               new_row[col_id_]->type().ToString() + "and rank: " + std::to_string(new_row[col_id_]->Rank()));
     read_done_ = true;

View File

@@ -80,7 +80,7 @@ Status BuildVocabOp::WorkerEntry(int32_t worker_id) {
 Status BuildVocabOp::operator()() {
   // launch the collector thread
   if (tree_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
   }
   RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks()));

View File

@@ -233,7 +233,7 @@ Status CacheBase::UpdateColumnMapFromCache() {
   // Get the schema from the server. It may not be there yet. So tolerate the error.
   if (column_name_id_map_.empty()) {
     rc = cache_client_->FetchSchema(&column_name_id_map_);
-    if (rc == Status(StatusCode::kFileNotExist)) {
+    if (rc == Status(StatusCode::kMDFileNotExist)) {
       MS_LOG(DEBUG) << "Schema not in the server yet.";
       rc = Status::OK();
     }
@@ -304,14 +304,14 @@ Status CacheBase::Prefetcher(int32_t worker_id) {
   int32_t retry_count = 0;
   do {
     rc = PrefetchRows(prefetch_keys, &cache_miss);
-    if (rc.IsNetWorkError() && retry_count < max_retries) {
+    if (rc == StatusCode::kMDNetWorkError && retry_count < max_retries) {
       // If we get some network error, we will attempt some retries
       retry_count++;
     } else if (rc.IsError()) {
       MS_LOG(WARNING) << rc.ToString();
       return rc;
     }
-  } while (rc.IsNetWorkError());
+  } while (rc == StatusCode::kMDNetWorkError);
   // In case any thread is waiting for the rows to come back and blocked on a semaphore,
   // we will put an empty row in the local cache.
   if (rc.IsError() && AllowCacheMiss()) {
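The CacheBase::Prefetcher hunk above also shows an API change beyond the rename: helper predicates such as rc.IsNetWorkError() are replaced by comparing the Status directly against a StatusCode (rc == StatusCode::kMDNetWorkError). A minimal sketch of that retry pattern follows, again with simplified stand-in types and a hypothetical PrefetchOnce helper in place of PrefetchRows.

#include <iostream>

enum class StatusCode { kSuccess, kMDNetWorkError, kMDUnexpectedError };

class Status {
 public:
  Status() = default;
  explicit Status(StatusCode c) : code_(c) {}
  bool IsOk() const { return code_ == StatusCode::kSuccess; }
  bool IsError() const { return !IsOk(); }
  // Assumed: the refactored Status can be compared against a StatusCode directly.
  bool operator==(StatusCode c) const { return code_ == c; }

 private:
  StatusCode code_ = StatusCode::kSuccess;
};

Status PrefetchOnce(int attempt) {
  // Hypothetical helper: fail with a network error on the first two attempts, then succeed.
  return attempt < 2 ? Status(StatusCode::kMDNetWorkError) : Status();
}

int main() {
  const int max_retries = 3;
  int retry_count = 0;
  Status rc;
  do {
    rc = PrefetchOnce(retry_count);
    if (rc == StatusCode::kMDNetWorkError && retry_count < max_retries) {
      retry_count++;  // transient network error: retry
    } else if (rc.IsError()) {
      return 1;       // any other error: give up
    }
  } while (rc == StatusCode::kMDNetWorkError);
  std::cout << "prefetch succeeded after " << retry_count << " retries\n";
  return 0;
}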

View File

@@ -39,12 +39,12 @@ CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_
 // Check if the required parameters are set by the builder.
 Status CacheLookupOp::Builder::SanityCheck() const {
   if (build_cache_client_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, CacheLookupOp requires a CacheClient, but got nullptr.");
   }
   // Make sure the cache client has a valid session
   if (!build_cache_client_->session_id()) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, cache client for CacheLookupOp requires a session id which is not equal to 0.");
   }
   return Status::OK();
@@ -59,7 +59,7 @@ Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
 }
 Status CacheLookupOp::operator()() {
   if (!sampler_) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, CacheLookupOp requires a sampler before it can be executed, but got nullptr.");
   }
   RETURN_IF_NOT_OK(RegisterResources());

View File

@@ -129,7 +129,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
   Status rc;
   if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) {
     cache_missing_rows_ = false;
-    if (rc.IsOutofMemory() || rc.IsNoSpace()) {
+    if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
       cache_client_->ServerRunningOutOfResources();
     } else {
       MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString();
@@ -156,7 +156,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
   rc = rq->AsyncSendCacheRequest(cache_client_, row);
   if (rc.IsOk()) {
     RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
-  } else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
+  } else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
     cache_missing_rows_ = false;
     cache_client_->ServerRunningOutOfResources();
   }
@@ -188,9 +188,9 @@ Status CacheMergeOp::Cleaner() {
   Status rc = rq->CheckCacheResult();
   if (rc.IsError()) {
     // If interrupt, time to quit.
-    if (rc.IsInterrupted()) {
+    if (rc == StatusCode::kMDInterrupted) {
       return Status::OK();
-    } else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
+    } else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
       // The server is hitting some limit and we will turn off caching from now on.
       cache_missing_rows_ = false;
       cache_client_->ServerRunningOutOfResources();
@@ -215,7 +215,7 @@ Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from supe
   // Construct the cache
   const bool generate_ids = false;
   Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
-  if (rc.get_code() == StatusCode::kDuplicateKey) {
+  if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
     // We are told the cache has been created already.
     MS_LOG(INFO) << "Cache created already";
     rc = Status::OK();
@@ -244,12 +244,12 @@ CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(
 // Check if the required parameters are set by the builder.
 Status CacheMergeOp::Builder::SanityCheck() const {
   if (build_cache_client_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, CacheMergeOp requires a CacheClient, but got nullptr.");
   }
   // Make sure the cache client has a valid session
   if (!build_cache_client_->session_id()) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, cache client for CacheMergeOp requires a session id which is not equal to 0.");
   }
   return Status::OK();
@@ -316,7 +316,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
   // We will do a deep copy but write directly into CacheRequest protobuf or shared memory
   Status rc;
   rc = cc->AsyncWriteRow(row);
-  if (rc.get_code() == StatusCode::kNotImplementedYet) {
+  if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
     cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
     rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
     if (rc.IsOk()) {
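CacheMergeOp (and CacheOp below) also switch from rc.get_code() to rc.StatusCode() when they need the raw enumerator, and compare against enumerators directly for the out-of-resource cases. A short stand-alone sketch of both spellings follows; the plain enum and toy Status are modelled on what these hunks assume, not copied from the MindSpore headers, and AsyncWriteRow is a hypothetical stub.

#include <iostream>

// Plain enum so enumerators can be used unqualified, as in `rc == kMDNoSpace` above.
enum StatusCode { kSuccess = 0, kMDOutOfMemory, kMDNoSpace, kMDDuplicateKey, kMDNotImplementedYet };

class Status {
 public:
  Status() = default;
  explicit Status(enum StatusCode c) : code_(c) {}
  bool IsOk() const { return code_ == kSuccess; }
  bool IsError() const { return !IsOk(); }
  enum StatusCode StatusCode() const { return code_; }  // replaces the old get_code()
  bool operator==(enum StatusCode c) const { return code_ == c; }

 private:
  enum StatusCode code_ = kSuccess;
};

// Hypothetical stub standing in for cc->AsyncWriteRow(row).
Status AsyncWriteRow() { return Status(kMDNotImplementedYet); }

int main() {
  Status rc = AsyncWriteRow();
  if (rc.StatusCode() == kMDNotImplementedYet) {
    std::cout << "async path not available, fall back to the synchronous write\n";
  }
  if (rc == kMDOutOfMemory || rc == kMDNoSpace) {
    std::cout << "cache server is running out of resources\n";
  }
  return 0;
}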

View File

@@ -41,12 +41,12 @@ CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullp
 // Check if the required parameters are set by the builder.
 Status CacheOp::Builder::SanityCheck() const {
   if (build_cache_client_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, CacheOp requires a CacheClient, but got nullptr.");
   }
   // Make sure the cache client has a valid session
   if (!build_cache_client_->session_id()) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, cache client for CacheOp requires a session id which is not equal to 0.");
   }
   return Status::OK();
@@ -78,7 +78,7 @@ Status CacheOp::InitCache() { return Status::OK(); }
 // This class functor will provide the master loop that drives the logic for performing the work
 Status CacheOp::operator()() {
   if (!sampler_) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                   "Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
   }
   RETURN_IF_NOT_OK(RegisterResources());
@@ -113,7 +113,7 @@ Status CacheOp::CacheAllRows(int32_t worker_id) {
   Status rc;
   // Do the Async write if we attach to the shared memory.
   rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr));
-  if (rc.get_code() == StatusCode::kNotImplementedYet) {
+  if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
     RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
   } else if (rc.IsError()) {
     return rc;
@@ -169,9 +169,9 @@ Status CacheOp::WaitForCachingAllRows() {
       BuildPhaseDone = true;
       break;
     case CacheServiceState::kOutOfMemory:
-      return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory");
+      return Status(StatusCode::kMDOutOfMemory, "Cache server is running out of memory");
     case CacheServiceState::kNoSpace:
-      return Status(StatusCode::kNoSpace, "Cache server is running of out spill storage");
+      return Status(StatusCode::kMDNoSpace, "Cache server is running of out spill storage");
     case CacheServiceState::kNone:
     case CacheServiceState::kError:
     default:
@@ -246,7 +246,7 @@ Status CacheOp::PrepareNodePostAction() {
   // Construct the cache
   const bool generate_ids = true;
   Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
-  if (rc.get_code() == StatusCode::kDuplicateKey) {
+  if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
     // We are told the cache has been created already. So we skip the build phase.
     phase_ = Phase::kFetchPhase;
     rc = Status::OK();

View File

@@ -157,18 +157,14 @@ Status DeviceQueueOp::SendDataToAscend() {
       TensorRow currRow;
       for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
         RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
-        while (stop_send_ && ascend_keep_waiting_) {
-          MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
-          std::this_thread::sleep_for(std::chrono::microseconds(100));
-        }
+        WaitContinueSignal();
         auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
         if (status == TdtStatus::FAILED) {
           if (stop_send_) {
             MS_LOG(INFO) << "stop_send received";
             return Status::OK();
-          } else {
-            return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
           }
+          return Status(StatusCode::kMDTDTPushFailure, "TDT Push Failed");
         }
         if (create_data_info_queue_) {
           DATA_INFO data_info;
@@ -200,9 +196,8 @@ Status DeviceQueueOp::SendDataToAscend() {
         if (stop_send_) {
           MS_LOG(INFO) << "stop_send received";
           return Status::OK();
-        } else {
-          return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
         }
+        return Status(StatusCode::kMDTDTPushFailure, "TDT Push Failed");
       }
       MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
       stop_send_ = true;
@@ -219,13 +214,19 @@ Status DeviceQueueOp::SendDataToAscend() {
   return Status::OK();
 }
+void DeviceQueueOp::WaitContinueSignal() const {
+  while (stop_send_ && ascend_keep_waiting_) {
+    MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
+    std::this_thread::sleep_for(std::chrono::microseconds(100));
+  }
+}
 #endif
 #ifdef ENABLE_TDTQUE
 Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
   if (!create_data_info_queue_) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
   }
   // This place has a race condition with operator(), so the first one
   // arrive here will do the initialize work.
@@ -241,7 +242,7 @@ Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
 }
 #else
 Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
-  return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
+  return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
 }
 #endif
@@ -301,7 +302,7 @@ Status DeviceQueueOp::PushDataToGPU() {
   }
   handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, release_function);
   if (handle == INVALID_HANDLE) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to open channel for sending data.");
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Failed to open channel for sending data.");
   }
   is_open = true;
 }
@@ -309,14 +310,14 @@ Status DeviceQueueOp::PushDataToGPU() {
   // Data prefetch only when PS mode enables cache.
   if (items.size() > 0) {
     if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
-      return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
+      return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
     }
   }
   while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
     BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
     if (ret) {
       if (ret == BlockQueueStatus_T::ERROR_INPUT) {
-        return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Invalid input data, please check it.");
+        return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Invalid input data, please check it.");
       } else {
         if (!stop_send_) {
           MS_LOG(DEBUG) << "Retry pushing data...";
@@ -438,13 +439,13 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
   for (auto &sub_item : *items) {
     RETURN_IF_NOT_OK(pool_[worker_id]->Allocate(sub_item.data_len_, &sub_item.data_ptr_));
     if (sub_item.data_ptr_ == nullptr) {
-      return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
+      return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
     }
     const unsigned char *column_data = curr_row[i]->GetBuffer();
     if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data,
                  static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) {
       MS_LOG(ERROR) << "memcpy_s failed!";
-      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memcpy_s failed.");
+      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "memcpy_s failed.");
     }
   }
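The SendDataToAscend hunks above replace the inline busy-wait with a call to a new private helper, WaitContinueSignal(), declared const in the header change below. A self-contained sketch of that shape follows; the class is a simplified stand-in rather than the real DeviceQueueOp, and the flags are modelled as atomics purely for the demo.

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

class DeviceQueueLikeOp {
 public:
  void RequestStop() { stop_send_ = true; }   // another thread would set this
  void Continue() { stop_send_ = false; }     // ...and later clear it
  void SendLoop() {
    for (int row = 0; row < 3; ++row) {
      WaitContinueSignal();  // blocks while stop_send_ is set and waiting is allowed
      std::cout << "push row " << row << "\n";
    }
  }

 private:
  // Same pattern as the extracted DeviceQueueOp::WaitContinueSignal() const helper.
  void WaitContinueSignal() const {
    while (stop_send_ && ascend_keep_waiting_) {
      std::this_thread::sleep_for(std::chrono::microseconds(100));
    }
  }
  std::atomic<bool> stop_send_{false};
  std::atomic<bool> ascend_keep_waiting_{true};
};

int main() {
  DeviceQueueLikeOp op;
  op.SendLoop();
  return 0;
}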

View File

@@ -190,6 +190,7 @@ class DeviceQueueOp : public PipelineOp {
  private:
 #ifdef ENABLE_TDTQUE
+  void WaitContinueSignal() const;
   Status SendDataToAscend();
   bool ascend_keep_waiting_;
 #endif

View File

@@ -43,7 +43,7 @@ Status FilterOp::Builder::SanityCheck() {
   err += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
                                        std::to_string(builder_num_workers_) + ".\n"
                                    : "";
-  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
+  return err.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
 }
 FilterOp::Builder::Builder() {
@@ -66,7 +66,7 @@ FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_wor
 Status FilterOp::operator()() {
   // The operator class just starts off threads by calling the tree_ function.
   if (tree_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
+    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
   }
   filter_queues_.Init(num_workers_, oc_queue_size_);
   RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
@@ -244,7 +244,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
   RETURN_IF_NOT_OK(predicate_func_->Compute(input, &output));
   RETURN_IF_NOT_OK(output.at(0)->GetItemAt(out_predicate, {}));
-  return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
+  return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
 }
 // Visitor accept method for NodePass

Some files were not shown because too many files have changed in this diff.