cxx api refactor: tensor/status/model
Parent: f1009cb21b
Commit: 7d2fd6e76c
@@ -65,7 +65,7 @@ install(
install(
    TARGETS mindspore_shared_lib
-   LIBRARY DESTINATION ${INSTALL_LIB_DIR}
+   DESTINATION ${INSTALL_LIB_DIR}
    COMPONENT mindspore
)

@@ -327,7 +327,7 @@ install(
    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/transforms.h
    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision.h
    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/vision_lite.h
    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/minddata_eager.h
    ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include/execute.h
    DESTINATION ${INSTALL_BASE_DIR}/include/minddata/dataset/include
    COMPONENT mindspore
)
@@ -109,6 +109,8 @@ if(PLATFORM_ARM64)
    COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
    COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
    COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend* ops*" EXCLUDE)
if(ENABLE_TOOLS)
    install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()

@@ -128,6 +130,8 @@ elseif(PLATFORM_ARM32)
    COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
    COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
    COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
if(ENABLE_TOOLS)
    install(TARGETS benchmark RUNTIME DESTINATION ${RUNTIME_PKG_NAME}/benchmark COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()

@@ -162,6 +166,8 @@ elseif(WIN32)
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
    COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
    COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
set(WIN_LIB_DIR_RUN_X86 ${RUNTIME_PKG_NAME}/benchmark)
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.a DESTINATION ${WIN_LIB_DIR_RUN_X86}
    COMPONENT ${RUNTIME_COMPONENT_NAME})

@@ -182,6 +188,8 @@ else()
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
    COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
    COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${RUNTIME_LIB_DIR}
    COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}
@@ -24,7 +24,6 @@
#include "include/api/graph.h"

namespace mindspore {
- namespace api {
class InputAndOutput;
using Input = InputAndOutput;
using Output = InputAndOutput;

@@ -35,7 +34,7 @@ class MS_API CellBase {
  virtual ~CellBase() = default;
  virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
  virtual std::shared_ptr<CellBase> Clone() const = 0;
- virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { return SUCCESS; }
+ virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; }
  std::vector<Output> operator()(const std::vector<Input> &inputs) const;
};

@@ -57,16 +56,16 @@ class MS_API ParameterCell final : public Cell<ParameterCell> {
  ParameterCell(ParameterCell &&);
  ParameterCell &operator=(ParameterCell &&);

- explicit ParameterCell(const Tensor &);
- ParameterCell &operator=(const Tensor &);
+ explicit ParameterCell(const MSTensor &);
+ ParameterCell &operator=(const MSTensor &);

- explicit ParameterCell(Tensor &&);
- ParameterCell &operator=(Tensor &&);
+ explicit ParameterCell(MSTensor &&);
+ ParameterCell &operator=(MSTensor &&);

- Tensor GetTensor() const { return tensor_; }
+ MSTensor GetTensor() const { return tensor_; }

 private:
- Tensor tensor_;
+ MSTensor tensor_;
};

class MS_API OpCellBase : public CellBase {

@@ -99,11 +98,9 @@ class MS_API GraphCell final : public Cell<GraphCell> {
  explicit GraphCell(const std::shared_ptr<Graph> &);

  const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
- Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
- Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                      std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
- Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
+ Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
+ std::vector<MSTensor> GetInputs();
+ std::vector<MSTensor> GetOutputs();

 private:
  friend class ModelImpl;

@@ -119,8 +116,8 @@ class MS_API InputAndOutput {
  ~InputAndOutput() = default;

  // no explicit
- InputAndOutput(const Tensor &);   // NOLINT(runtime/explicit)
- InputAndOutput(Tensor &&);        // NOLINT(runtime/explicit)
+ InputAndOutput(const MSTensor &); // NOLINT(runtime/explicit)
+ InputAndOutput(MSTensor &&);      // NOLINT(runtime/explicit)

  InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);

@@ -132,6 +129,5 @@ class MS_API InputAndOutput {
  std::vector<InputAndOutput> prev_;
  int32_t index_;
};
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CELL_H
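For orientation: the Input/Output aliases plus CellBase::operator() make cells directly callable on tensors. A minimal sketch, assuming only the header above (the wrapper function RunCell is illustrative, not part of the API):

#include <vector>
#include "include/api/cell.h"

std::vector<mindspore::Output> RunCell(const mindspore::GraphCell &cell,
                                       const std::vector<mindspore::Input> &inputs) {
  // operator() clones the cell and forwards to Construct (see cell.cc below).
  return cell(inputs);
}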
@@ -16,26 +16,49 @@
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

#include <map>
#include <any>
#include <string>
#include <memory>
#include "include/api/types.h"

namespace mindspore {
- namespace api {
- class MS_API Context {
-  public:
-   static Context &Instance();
-   const std::string &GetDeviceTarget() const;
-   Context &SetDeviceTarget(const std::string &device_target);
-   uint32_t GetDeviceID() const;
-   Context &SetDeviceID(uint32_t device_id);
-
-  private:
-   Context();
-   ~Context();
-   class ContextImpl;
-   std::shared_ptr<ContextImpl> impl_;
- };
+ constexpr auto kDeviceTypeAscend310 = "Ascend310";
+ constexpr auto kDeviceTypeAscend910 = "Ascend910";
+
+ struct MS_API Context {
+   virtual ~Context() = default;
+   std::map<std::string, std::any> params;
+ };
+
+ struct MS_API GlobalContext : public Context {
+   static std::shared_ptr<Context> GetGlobalContext();
+
+   static void SetGlobalDeviceTarget(const std::string &device_target);
+   static std::string GetGlobalDeviceTarget();
+
+   static void SetGlobalDeviceID(const uint32_t &device_id);
+   static uint32_t GetGlobalDeviceID();
+ };
+
+ struct MS_API ModelContext : public Context {
+   static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
+   static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
+
+   static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
+   static std::string GetInputFormat(const std::shared_ptr<Context> &context);
+
+   static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
+   static std::string GetInputShape(const std::shared_ptr<Context> &context);
+
+   static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
+   static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
+
+   static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
+   static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
+
+   static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
+   static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
+ };
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
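A minimal usage sketch of the new split between GlobalContext and ModelContext, assuming only the header above (the surrounding function name is illustrative):

#include <memory>
#include "include/api/context.h"

void ConfigureExample() {
  // Process-wide settings go through the GlobalContext statics.
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  mindspore::GlobalContext::SetGlobalDeviceID(0);
  // Per-model options are stashed in a Context's params map via ModelContext.
  auto model_context = std::make_shared<mindspore::Context>();
  mindspore::ModelContext::SetInputFormat(model_context, "NCHW");
  mindspore::ModelContext::SetOutputType(model_context, mindspore::DataType::kNumberTypeFloat32);
}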
@@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_
#define MINDSPORE_INCLUDE_API_DATA_TYPE_H_

namespace mindspore {
enum class DataType : int {
  kTypeUnknown = 0,
  kObjectTypeString = 12,
  kObjectTypeList = 13,
  kObjectTypeTuple = 14,
  kObjectTypeTensorType = 17,
  kNumberTypeBool = 30,
  kNumberTypeInt8 = 32,
  kNumberTypeInt16 = 33,
  kNumberTypeInt32 = 34,
  kNumberTypeInt64 = 35,
  kNumberTypeUInt8 = 37,
  kNumberTypeUInt16 = 38,
  kNumberTypeUInt32 = 39,
  kNumberTypeUInt64 = 40,
  kNumberTypeFloat16 = 42,
  kNumberTypeFloat32 = 43,
  kNumberTypeFloat64 = 44,
  kNumberTypeEnd = 46,
  // add new enum here
  kInvalidType = INT32_MAX,
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_DATA_TYPE_H_
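These values track mindspore/core/ir/dtype/type_id.h, which is why the packaging hunks above start shipping type_id.h. As a quick orientation, a standalone sketch of how calling code might map the numeric types to element widths; the helper name ElementSizeOf is illustrative, not part of the API:

#include <cstddef>
#include "include/api/data_type.h"

// Illustrative helper (not part of the public API): element width in bytes.
size_t ElementSizeOf(mindspore::DataType t) {
  switch (t) {
    case mindspore::DataType::kNumberTypeBool:
    case mindspore::DataType::kNumberTypeInt8:
    case mindspore::DataType::kNumberTypeUInt8: return 1;
    case mindspore::DataType::kNumberTypeInt16:
    case mindspore::DataType::kNumberTypeUInt16:
    case mindspore::DataType::kNumberTypeFloat16: return 2;
    case mindspore::DataType::kNumberTypeInt32:
    case mindspore::DataType::kNumberTypeUInt32:
    case mindspore::DataType::kNumberTypeFloat32: return 4;
    case mindspore::DataType::kNumberTypeInt64:
    case mindspore::DataType::kNumberTypeUInt64:
    case mindspore::DataType::kNumberTypeFloat64: return 8;
    default: return 0;  // strings/containers have no fixed element width
  }
}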
@@ -16,6 +16,7 @@
#ifndef MINDSPORE_INCLUDE_API_GRAPH_H
#define MINDSPORE_INCLUDE_API_GRAPH_H

+ #include <cstddef>
#include <string>
#include <vector>
#include <map>

@@ -24,21 +25,21 @@
#include "include/api/types.h"

namespace mindspore {
- namespace api {
class MS_API Graph {
 public:
  class GraphData;
  explicit Graph(const std::shared_ptr<GraphData> &graph_data);
  explicit Graph(std::shared_ptr<GraphData> &&graph_data);
+ explicit Graph(std::nullptr_t);
  ~Graph();

  enum ModelType ModelType() const;
+ bool operator==(std::nullptr_t) const;

 private:
  friend class GraphCell;
  friend class ModelImpl;
  std::shared_ptr<GraphData> graph_data_;
};
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_GRAPH_H
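The std::nullptr_t constructor and comparison added above give callers a way to represent and test a failed load. A hedged sketch, assuming LoadModel yields a null graph on failure rather than throwing:

#include <string>
#include "include/api/graph.h"
#include "include/api/serialization.h"

bool IsLoadable(const std::string &path) {
  auto graph = mindspore::Serialization::LoadModel(path, mindspore::ModelType::kMindIR);
  return !(graph == nullptr);  // only operator==(std::nullptr_t) is declared, so negate it
}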
@@ -0,0 +1,77 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
#define MINDSPORE_INCLUDE_API_LITE_CONTEXT_H

#include <string>
#include <memory>
#include <map>
#include <any>
#include "include/api/types.h"

namespace mindspore {
namespace lite {
/// \brief CpuBindMode defined for holding bind cpu strategy argument.
typedef enum : uint32_t {
  NO_BIND = 0,    /**< no bind */
  HIGHER_CPU = 1, /**< bind higher cpu first */
  MID_CPU = 2     /**< bind middle cpu first */
} CpuBindMode;

class Allocator;
}  // namespace lite

struct MS_API Context {
 public:
  static void Clear(const std::shared_ptr<Context> &context);

  static void SetAsDefault(const std::shared_ptr<Context> &context);

  static void SetVendorName(const std::shared_ptr<Context> &context, const std::string &name);
  static std::string GetVendorName(const std::shared_ptr<Context> &context);

  static void SetThreadNum(const std::shared_ptr<Context> &context, int num);
  static int GetThreadNum(const std::shared_ptr<Context> &context);

  static void SetAllocator(const std::shared_ptr<Context> &context, std::shared_ptr<lite::Allocator> alloc);
  static std::shared_ptr<lite::Allocator> GetAllocator(const std::shared_ptr<Context> &context);

  static void ConfigCPU(const std::shared_ptr<Context> &context, bool config);
  static bool IfCPUEnabled(const std::shared_ptr<Context> &context);

  static void ConfigCPUFp16(const std::shared_ptr<Context> &context, bool config);
  static bool IfCPUFp16Enabled(const std::shared_ptr<Context> &context);

  static void SetCPUBindMode(const std::shared_ptr<Context> &context, lite::CpuBindMode mode);
  static lite::CpuBindMode GetCPUBindMode(const std::shared_ptr<Context> &context);

  static void ConfigGPU(const std::shared_ptr<Context> &context, bool config);
  static bool IfGPUEnabled(const std::shared_ptr<Context> &context);

  static void ConfigGPUFp16(const std::shared_ptr<Context> &context, bool config);
  static bool IfGPUFp16Enabled(const std::shared_ptr<Context> &context);

  static void ConfigNPU(const std::shared_ptr<Context> &context, bool config);
  static bool IfNPUEnabled(const std::shared_ptr<Context> &context);

  static void SetNPUFrequency(const std::shared_ptr<Context> &context, int freq);
  static int GetNPUFrequency(const std::shared_ptr<Context> &context);

 private:
  std::map<std::string, std::any> context_;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
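For orientation, a hedged sketch of how a caller might drive the lite context above; all setters are statics that stash values in the private map:

#include <memory>
#include "include/api/lite_context.h"

void ConfigureLiteExample() {
  auto context = std::make_shared<mindspore::Context>();
  mindspore::Context::SetThreadNum(context, 2);
  mindspore::Context::ConfigCPU(context, true);
  mindspore::Context::SetCPUBindMode(context, mindspore::lite::MID_CPU);
}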
@@ -20,41 +20,36 @@
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/cell.h"

namespace mindspore {
- namespace api {
class ModelImpl;
- // todo: minddata c++ interface
- class DataSet {};
+ struct Context;

class MS_API Model {
 public:
- explicit Model(const std::vector<Output> &network);
- explicit Model(const GraphCell &graph);
+ explicit Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context = nullptr);
+ explicit Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context = nullptr);
  ~Model();
  Model(const Model &) = delete;
  void operator=(const Model &) = delete;

- Status Build(const std::map<std::string, std::string> &options);
+ Status Build();
+ Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);

- Status Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
- Status Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
- Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
+ Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

- Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                      std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
- Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
+ std::vector<MSTensor> GetInputs();
+ std::vector<MSTensor> GetOutputs();

  static bool CheckModelSupport(const std::string &device_type, ModelType model_type);

 private:
  std::shared_ptr<ModelImpl> impl_;
};
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_MODEL_H
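Putting the refactored pieces together, a hedged end-to-end sketch; the model file name is a placeholder, and Serialization, GraphCell, and the context API come from the headers elsewhere in this diff:

#include <vector>
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/context.h"

int main() {
  mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
  mindspore::GlobalContext::SetGlobalDeviceID(0);

  // Load a MindIR graph and wrap it in a GraphCell for the Model.
  auto graph = mindspore::Serialization::LoadModel("net.mindir", mindspore::ModelType::kMindIR);
  mindspore::Model model((mindspore::GraphCell(graph)));
  if (model.Build() != mindspore::kSuccess) {
    return 1;
  }

  // Query input tensors from the model and run inference.
  std::vector<mindspore::MSTensor> outputs;
  if (model.Predict(model.GetInputs(), &outputs) != mindspore::kSuccess) {
    return 1;
  }
  return 0;
}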
@@ -25,7 +25,6 @@
#include "include/api/cell.h"

namespace mindspore {
- namespace api {
struct MS_API Conv2D : public OpCell<Conv2D> {
  Conv2D() : OpCell("Conv2D") {}
  ~Conv2D() override = default;

@@ -45,6 +44,5 @@ struct MS_API Conv2D : public OpCell<Conv2D> {
  std::vector<int> dilation = {1, 1, 1, 1};
  int group = 1;
};
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_OPS_OPS_H
@@ -26,15 +26,14 @@
#include "include/api/graph.h"

namespace mindspore {
- namespace api {
class MS_API Serialization {
 public:
  static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
  static Graph LoadModel(const std::string &file, ModelType model_type);
  static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
  static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
  static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
  static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
};
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H
@@ -17,37 +17,129 @@
#define MINDSPORE_INCLUDE_API_STATUS_H

#include <string>
#include <ostream>
#include <climits>

namespace mindspore {
- namespace api {
- enum StatusCode {
-   SUCCESS = 0,
-   FAILED,
-   INVALID_INPUTS,
-   // insert new status code here
-   UNKNOWN = 0xFFFFFFFF
- };
+ enum CompCode : uint32_t {
+   kCore = 0x00000000u,
+   kMD = 0x10000000u,
+   kME = 0x20000000u,
+   kMC = 0x30000000u,
+   kLite = 0xF0000000u,
+ };
+
+ enum StatusCode : uint32_t {
+   kSuccess = 0,
+   // Core
+   kCoreFailed = kCore | 0x1,
+
+   // MD
+   kMDOutOfMemory = kMD | 1,
+   kMDShapeMisMatch = kMD | 2,
+   kMDInterrupted = kMD | 3,
+   kMDNoSpace = kMD | 4,
+   kMDPyFuncException = kMD | 5,
+   kMDDuplicateKey = kMD | 6,
+   kMDPythonInterpreterFailure = kMD | 7,
+   kMDTDTPushFailure = kMD | 8,
+   kMDFileNotExist = kMD | 9,
+   kMDProfilingError = kMD | 10,
+   kMDBoundingBoxOutOfBounds = kMD | 11,
+   kMDBoundingBoxInvalidShape = kMD | 12,
+   kMDSyntaxError = kMD | 13,
+   kMDTimeOut = kMD | 14,
+   kMDBuddySpaceFull = kMD | 15,
+   kMDNetWorkError = kMD | 16,
+   kMDNotImplementedYet = kMD | 17,
+   // Make this error code the last one. Add new error code above it.
+   kMDUnexpectedError = kMD | 127,
+
+   // ME
+   kMEFailed = kME | 0x1,
+   kMEInvalidInput = kME | 0x2,
+
+   // MC
+   kMCFailed = kMC | 0x1,
+   kMCDeviceError = kMC | 0x2,
+   kMCInvalidInput = kMC | 0x3,
+   kMCInvalidArgs = kMC | 0x4,
+
+   // Lite  // Common error code, range: [-1, -100)
+   kLiteError = kLite | (0x0FFFFFFF & -1),           /**< Common error code. */
+   kLiteNullptr = kLite | (0x0FFFFFFF & -2),         /**< NULL pointer returned.*/
+   kLiteParamInvalid = kLite | (0x0FFFFFFF & -3),    /**< Invalid parameter.*/
+   kLiteNoChange = kLite | (0x0FFFFFFF & -4),        /**< No change. */
+   kLiteSuccessExit = kLite | (0x0FFFFFFF & -5),     /**< No error but exit. */
+   kLiteMemoryFailed = kLite | (0x0FFFFFFF & -6),    /**< Fail to create memory. */
+   kLiteNotSupport = kLite | (0x0FFFFFFF & -7),      /**< Fail to support. */
+   kLiteThreadPoolError = kLite | (0x0FFFFFFF & -8), /**< Error occur in thread pool. */
+
+   // Executor error code, range: [-100,-200)
+   kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
+   kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
+   kLiteReentrantError = kLite | (0x0FFFFFFF & -102),   /**< Exist executor running. */
+
+   // Graph error code, range: [-200,-300)
+   kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */
+
+   // Node error code, range: [-300,-400)
+   kLiteNotFindOp = kLite | (0x0FFFFFFF & -300),        /**< Failed to find operator. */
+   kLiteInvalidOpName = kLite | (0x0FFFFFFF & -301),    /**< Invalid operator name. */
+   kLiteInvalidOpAttr = kLite | (0x0FFFFFFF & -302),    /**< Invalid operator attr. */
+   kLiteOpExecuteFailure = kLite | (0x0FFFFFFF & -303), /**< Failed to execution operator. */
+
+   // Tensor error code, range: [-400,-500)
+   kLiteFormatError = kLite | (0x0FFFFFFF & -400), /**< Failed to checking tensor format. */
+
+   // InferShape error code, range: [-500,-600)
+   kLiteInferError = kLite | (0x0FFFFFFF & -500),   /**< Failed to infer shape. */
+   kLiteInferInvalid = kLite | (0x0FFFFFFF & -501), /**< Invalid infer shape before runtime. */
+
+   // User input param error code, range: [-600, 700)
+   kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
+ };

class Status {
 public:
- Status() : status_code_(FAILED) {}
- Status(enum StatusCode status_code, const std::string &status_msg = "")  // NOLINT(runtime/explicit)
-   : status_code_(status_code), status_msg_(status_msg) {}
+ Status() : status_code_(kSuccess), line_of_code_(-1) {}
+ Status(enum StatusCode status_code, const std::string &status_msg = "")  // NOLINT(runtime/explicit)
+   : status_code_(status_code), status_msg_(status_msg), line_of_code_(-1) {}
+ Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");

  ~Status() = default;

- bool IsSuccess() const { return status_code_ == SUCCESS; }
  enum StatusCode StatusCode() const { return status_code_; }
- std::string StatusMessage() const { return status_msg_; }
+ const std::string &ToString() const { return status_msg_; }

+ int GetLineOfCode() const { return line_of_code_; }
+ const std::string &GetErrDescription() const { return status_msg_; }
+ const std::string &SetErrDescription(const std::string &err_description);

+ friend std::ostream &operator<<(std::ostream &os, const Status &s);

  bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
  bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
  bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
  bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }

- operator bool() const = delete;
+ explicit operator bool() const { return (status_code_ == kSuccess); }
+ explicit operator int() const { return static_cast<int>(status_code_); }

+ static Status OK() { return Status(StatusCode::kSuccess); }
+ bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
+ bool IsError() const { return !IsOk(); }

+ static std::string CodeAsString(enum StatusCode c);

 private:
  enum StatusCode status_code_;
  std::string status_msg_;
+ int line_of_code_;
+ std::string file_name_;
+ std::string err_description_;
};
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_STATUS_H
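The packing scheme is worth spelling out: the high bits carry the component (CompCode) and the low 28 bits the component-local code, so Lite's historical negative codes survive masked into the low bits. A standalone arithmetic check (plain C++, no MindSpore dependency; constants copied from the header above):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t kMD = 0x10000000u;
  const uint32_t kLite = 0xF0000000u;
  // kMDOutOfMemory = kMD | 1
  std::printf("kMDOutOfMemory = 0x%08X\n", kMD | 1u);  // 0x10000001
  // kLiteError = kLite | (0x0FFFFFFF & -1): the negative code keeps only its low 28 bits.
  std::printf("kLiteError     = 0x%08X\n", kLite | (0x0FFFFFFFu & static_cast<uint32_t>(-1)));  // 0xFFFFFFFF
  std::printf("kLiteNullptr   = 0x%08X\n", kLite | (0x0FFFFFFFu & static_cast<uint32_t>(-2)));  // 0xFFFFFFFE
  return 0;
}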
@@ -16,15 +16,20 @@
#ifndef MINDSPORE_INCLUDE_API_TYPES_H
#define MINDSPORE_INCLUDE_API_TYPES_H

#include <cstddef>
#include <string>
#include <vector>
#include <memory>
+ #include "include/api/data_type.h"

#ifdef _WIN32
#define MS_API __declspec(dllexport)
#else
#define MS_API __attribute__((visibility("default")))
#endif

namespace mindspore {
- namespace api {
- enum ModelType {
+ enum ModelType : uint32_t {
  kMindIR = 0,
  kAIR = 1,
  kOM = 2,

@@ -33,52 +38,38 @@ enum ModelType {
  kUnknownType = 0xFFFFFFFF
};

- enum DataType {
-   kMsUnknown = 0,
-   kMsBool = 1,
-   kMsInt8 = 2,
-   kMsInt16 = 3,
-   kMsInt32 = 4,
-   kMsInt64 = 5,
-   kMsUint8 = 6,
-   kMsUint16 = 7,
-   kMsUint32 = 8,
-   kMsUint64 = 9,
-   kMsFloat16 = 10,
-   kMsFloat32 = 11,
-   kMsFloat64 = 12,
-   // insert new data type here
-   kInvalidDataType = 0xFFFFFFFF
- };

- class MS_API Tensor {
+ class MS_API MSTensor {
 public:
- Tensor();
- Tensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, size_t data_len);
- ~Tensor();
- class Impl;
+ static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
+                              const void *data, size_t data_len) noexcept;
+ static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
+                                 const void *data, size_t data_len) noexcept;
+
+ MSTensor();
+ explicit MSTensor(const std::shared_ptr<Impl> &impl);
+ MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
+          size_t data_len);
+ ~MSTensor();

  const std::string &Name() const;
- void SetName(const std::string &name);

- api::DataType DataType() const;
- void SetDataType(api::DataType type);
+ enum DataType DataType() const;
  const std::vector<int64_t> &Shape() const;
- void SetShape(const std::vector<int64_t> &shape);
- int64_t ElementNum() const;

- const void *Data() const;
+ std::shared_ptr<const void> Data() const;
  void *MutableData();
  size_t DataSize() const;

- bool ResizeData(size_t data_len);
- bool SetData(const void *data, size_t data_len);
  bool IsDevice() const;

+ int64_t ElementNum() const;
- static int GetTypeSize(api::DataType type);
- Tensor Clone() const;
+ MSTensor Clone() const;
+ bool operator==(std::nullptr_t) const;

 private:
+ class Impl;
  friend class ModelImpl;
+ explicit MSTensor(std::nullptr_t);
  std::shared_ptr<Impl> impl_;
};

@@ -101,21 +92,5 @@ class MS_API Buffer {
  class Impl;
  std::shared_ptr<Impl> impl_;
};

- extern MS_API const char *kDeviceTypeAscend310;
- extern MS_API const char *kDeviceTypeAscend910;
- extern MS_API const char *kDeviceTypeGpu;
-
- constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
- constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
- constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";  // nchw or nhwc
- // Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
- constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
- constexpr auto kModelOptionOutputType = "mindspore.option.output_type";  // "FP32", "UINT8" or "FP16", default as "FP32"
- constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
- // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
- constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
- // "high_precision" or "high_performance", default as "high_performance"
- } // namespace api
} // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_TYPES_H
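A hedged sketch of the two construction paths MSTensor now exposes; judging by the names, CreateTensor copies the buffer while CreateRefTensor only wraps it — treat that ownership split as an assumption:

#include "include/api/types.h"

void TensorExample() {
  float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  // Owning tensor: the bytes are (presumably) copied in.
  auto owned = mindspore::MSTensor::CreateTensor("x", mindspore::DataType::kNumberTypeFloat32,
                                                 {2, 2}, data, sizeof(data));
  // Referencing tensor: wraps the caller's buffer, so it must not outlive `data`.
  auto ref = mindspore::MSTensor::CreateRefTensor("x_ref", mindspore::DataType::kNumberTypeFloat32,
                                                  {2, 2}, data, sizeof(data));
}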
@@ -23,7 +23,7 @@ if(ENABLE_D)
endif()

if(ENABLE_GPU)
-   file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
+   file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/gpu/*.cc")
endif()

set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc

@@ -45,8 +45,13 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
    target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
        -Wl,-force_load mindspore -Wl,-noall_load proto_input mindspore_gvar mindspore::protobuf)
else()
-   target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
+   if(ENABLE_D OR ENABLE_ACL)
+       target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
            -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
+   else()
+       target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
+           mindspore proto_input mindspore_gvar mindspore::protobuf)
+   endif()
endif()

if(ENABLE_CPU)
@@ -18,7 +18,7 @@
#include "cxx_api/factory.h"
#include "cxx_api/graph/graph_impl.h"

- namespace mindspore::api {
+ namespace mindspore {
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }

ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {}

@@ -40,23 +40,23 @@ ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
  return *this;
}

- ParameterCell::ParameterCell(const Tensor &tensor) : tensor_(tensor.Clone()) {}
+ ParameterCell::ParameterCell(const MSTensor &tensor) : tensor_(tensor.Clone()) {}

- ParameterCell &ParameterCell::operator=(const Tensor &tensor) {
+ ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
  tensor_ = tensor.Clone();
  return *this;
}

- ParameterCell::ParameterCell(Tensor &&tensor) : tensor_(tensor) {}
+ ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}

- ParameterCell &ParameterCell::operator=(Tensor &&tensor) {
+ ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
  tensor_ = tensor;
  return *this;
}

GraphCell::GraphCell(const Graph &graph)
    : graph_(std::make_shared<Graph>(graph)),
-     executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+     executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
  MS_EXCEPTION_IF_NULL(graph_);
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->SetGraph(graph_);

@@ -64,7 +64,7 @@ GraphCell::GraphCell(const Graph &graph)

GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
    : graph_(graph),
-     executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+     executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
  MS_EXCEPTION_IF_NULL(graph_);
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->SetGraph(graph_);

@@ -72,13 +72,13 @@ GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)

GraphCell::GraphCell(Graph &&graph)
    : graph_(std::make_shared<Graph>(graph)),
-     executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
+     executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
  MS_EXCEPTION_IF_NULL(graph_);
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->SetGraph(graph_);
}

- Status GraphCell::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+ Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->Run(inputs, outputs);
}

@@ -88,25 +88,24 @@ Status GraphCell::Load() {
  return executor_->Load();
}

- Status GraphCell::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                 std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
+ std::vector<MSTensor> GraphCell::GetInputs() {
  MS_EXCEPTION_IF_NULL(executor_);
- return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes);
+ return executor_->GetInputs();
}

- Status GraphCell::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                  std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
+ std::vector<MSTensor> GraphCell::GetOutputs() {
  MS_EXCEPTION_IF_NULL(executor_);
- return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
+ return executor_->GetOutputs();
}

InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}

- InputAndOutput::InputAndOutput(const Tensor &tensor)
+ InputAndOutput::InputAndOutput(const MSTensor &tensor)
    : cell_(std::make_shared<ParameterCell>(tensor.Clone())), prev_(), index_(-1) {}
- InputAndOutput::InputAndOutput(Tensor &&tensor) : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
+ InputAndOutput::InputAndOutput(MSTensor &&tensor)
+   : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}

InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
                               int32_t index)
    : cell_(cell), prev_(prev), index_(index) {}
- } // namespace mindspore::api
+ } // namespace mindspore
@@ -16,49 +16,119 @@
#include "include/api/context.h"
#include "utils/log_adapter.h"

- namespace mindspore::api {
- class Context::ContextImpl {
-  public:
-   ContextImpl() : device_target_("NotSet"), device_id_(0) {}
-   ~ContextImpl() = default;
-   const std::string &GetDeviceTarget() const { return device_target_; }
-   void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; }
-   uint32_t GetDeviceID() const { return device_id_; }
-   void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
-
-  private:
-   std::string device_target_;
-   uint32_t device_id_;
- };
+ namespace mindspore {
+ constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
+ constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
+ constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
+ constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";  // nchw or nhwc
+ constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
+ // Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
+ constexpr auto kModelOptionOutputType = "mindspore.option.output_type";  // "FP32", "UINT8" or "FP16", default as "FP32"
+ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
+ // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
+ constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
+
+ template <class T>
+ static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
+   auto iter = context->params.find(key);
+   if (iter == context->params.end()) {
+     return T();
+   }
+   const std::any &value = iter->second;
+   if (value.type() != typeid(T)) {
+     return T();
+   }
+   return std::any_cast<T>(value);
+ }

- Context &Context::Instance() {
-   static Context context;
-   return context;
- }
-
- const std::string &Context::GetDeviceTarget() const {
-   MS_EXCEPTION_IF_NULL(impl_);
-   return impl_->GetDeviceTarget();
- }
+ std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
+   static std::shared_ptr<Context> g_context = std::make_shared<Context>();
+   return g_context;
+ }

- Context &Context::SetDeviceTarget(const std::string &device_target) {
-   MS_EXCEPTION_IF_NULL(impl_);
-   impl_->SetDeviceTarget(device_target);
-   return *this;
- }
+ void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
+   auto global_context = GetGlobalContext();
+   MS_EXCEPTION_IF_NULL(global_context);
+   global_context->params[kGlobalContextDeviceTarget] = device_target;
+ }

- uint32_t Context::GetDeviceID() const {
-   MS_EXCEPTION_IF_NULL(impl_);
-   return impl_->GetDeviceID();
- }
+ std::string GlobalContext::GetGlobalDeviceTarget() {
+   auto global_context = GetGlobalContext();
+   MS_EXCEPTION_IF_NULL(global_context);
+   return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
+ }

- Context &Context::SetDeviceID(uint32_t device_id) {
-   MS_EXCEPTION_IF_NULL(impl_);
-   impl_->SetDeviceID(device_id);
-   return *this;
- }
+ void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
+   auto global_context = GetGlobalContext();
+   MS_EXCEPTION_IF_NULL(global_context);
+   global_context->params[kGlobalContextDeviceID] = device_id;
+ }

- Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); }
- Context::~Context() {}
- } // namespace mindspore::api
+ uint32_t GlobalContext::GetGlobalDeviceID() {
+   auto global_context = GetGlobalContext();
+   MS_EXCEPTION_IF_NULL(global_context);
+   return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
+ }

+ void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
+   MS_EXCEPTION_IF_NULL(context);
+   context->params[kModelOptionInsertOpCfgPath] = cfg_path;
+ }
+
+ std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
+   MS_EXCEPTION_IF_NULL(context);
+   return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
+ }
+
+ void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
+   MS_EXCEPTION_IF_NULL(context);
+   context->params[kModelOptionInputFormat] = format;
+ }
+
+ std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
+   MS_EXCEPTION_IF_NULL(context);
+   return GetValue<std::string>(context, kModelOptionInputFormat);
+ }
+
+ void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
+   MS_EXCEPTION_IF_NULL(context);
+   context->params[kModelOptionInputShape] = shape;
+ }
+
+ std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
+   MS_EXCEPTION_IF_NULL(context);
+   return GetValue<std::string>(context, kModelOptionInputShape);
+ }
+
+ void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
+   MS_EXCEPTION_IF_NULL(context);
+   context->params[kModelOptionOutputType] = output_type;
+ }
+
+ enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
+   MS_EXCEPTION_IF_NULL(context);
+   return GetValue<enum DataType>(context, kModelOptionOutputType);
+ }
+
+ void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
+   MS_EXCEPTION_IF_NULL(context);
+   context->params[kModelOptionPrecisionMode] = precision_mode;
+ }
+
+ std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
+   MS_EXCEPTION_IF_NULL(context);
+   return GetValue<std::string>(context, kModelOptionPrecisionMode);
+ }
+
+ void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
+                                        const std::string &op_select_impl_mode) {
+   MS_EXCEPTION_IF_NULL(context);
+   context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
+ }
+
+ std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
+   MS_EXCEPTION_IF_NULL(context);
+   return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
+ }
+ } // namespace mindspore
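One behavioral note on the GetValue<T> helper above: a missing key and a type-mismatched std::any both return a value-initialized T, so the getters degrade to defaults instead of throwing. A small sketch, assuming only the code in this diff:

#include "include/api/context.h"

void DefaultsExample() {
  // Before SetGlobalDeviceID is ever called, the key is absent and the
  // getter falls back to uint32_t{} == 0.
  uint32_t id = mindspore::GlobalContext::GetGlobalDeviceID();
  (void)id;
}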
@@ -23,7 +23,7 @@
#include <utility>
#include "utils/utils.h"

- namespace mindspore::api {
+ namespace mindspore {
template <class T>
class Factory {
  using U = std::function<std::shared_ptr<T>()>;

@@ -79,5 +79,5 @@ class Registrar {
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS)                             \
  static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
    #DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
- } // namespace mindspore::api
+ } // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_FACTORY_H
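For reference, the registration macro is used once per backend; the Ascend 310 line below is taken verbatim from acl_graph_impl.cc later in this diff:

// Registers AclGraphImpl under the "Ascend310" key, so
// Factory<GraphCell::GraphImpl>::Instance().Create("Ascend310") produces it.
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);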
@@ -17,8 +17,8 @@
#include "utils/log_adapter.h"
#include "acl/acl.h"

- namespace mindspore::api {
- std::weak_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
+ namespace mindspore {
+ std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
std::mutex AclEnvGuard::global_acl_env_mutex_;

AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {

@@ -42,7 +42,7 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
  std::shared_ptr<AclEnvGuard> acl_env;

  std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
- acl_env = global_acl_env_.lock();
+ acl_env = global_acl_env_;
  if (acl_env != nullptr) {
    MS_LOG(INFO) << "Acl has been initialized, skip.";
  } else {

@@ -57,4 +57,4 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
  }
  return acl_env;
}
- } // namespace mindspore::api
+ } // namespace mindspore
@@ -20,7 +20,7 @@
#include <mutex>
#include "acl/acl_base.h"

- namespace mindspore::api {
+ namespace mindspore {
class __attribute__((visibility("default"))) AclEnvGuard {
 public:
  explicit AclEnvGuard(std::string_view cfg_file);

@@ -29,10 +29,10 @@ class __attribute__((visibility("default"))) AclEnvGuard {
  static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file);

 private:
- static std::weak_ptr<AclEnvGuard> global_acl_env_;
+ static std::shared_ptr<AclEnvGuard> global_acl_env_;
  static std::mutex global_acl_env_mutex_;

  aclError errno_;
};
- } // namespace mindspore::api
+ } // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_ENV_GUARD_H
@@ -16,53 +16,50 @@
#include "cxx_api/graph/acl/acl_graph_impl.h"
#include "include/api/context.h"
#include "cxx_api/model/acl/model_converter.h"
#include "cxx_api/python_utils.h"
#include "utils/log_adapter.h"

- namespace mindspore::api {
+ namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);

AclGraphImpl::AclGraphImpl()
    : init_flag_(false),
      load_flag_(false),
      device_type_("AscendCL"),
-     device_id_(Context::Instance().GetDeviceID()),
+     device_id_(GlobalContext::GetGlobalDeviceID()),
      context_(nullptr),
      acl_env_(nullptr) {}

AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); }

- Status AclGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+ Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  Status ret = Load();
- if (ret != SUCCESS) {
+ if (ret != kSuccess) {
    MS_LOG(ERROR) << "Prepare model resource failed.";
-   return FAILED;
+   return ret;
  }

  return model_process_.PredictFromHost(inputs, outputs);
}

- Status AclGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                    std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+ std::vector<MSTensor> AclGraphImpl::GetInputs() {
  Status ret = Load();
- if (ret != SUCCESS) {
+ if (ret != kSuccess) {
    MS_LOG(ERROR) << "Prepare model resource failed.";
-   return FAILED;
+   return {};
  }

- return model_process_.GetInputsInfo(names, shapes, data_types, mem_sizes);
+ return model_process_.GetInputs();
}

- Status AclGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                     std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+ std::vector<MSTensor> AclGraphImpl::GetOutputs() {
  Status ret = Load();
- if (ret != SUCCESS) {
+ if (ret != kSuccess) {
    MS_LOG(ERROR) << "Prepare model resource failed.";
-   return FAILED;
+   return {};
  }

- return model_process_.GetOutputsInfo(names, shapes, data_types, mem_sizes);
+ return model_process_.GetOutputs();
}

Status AclGraphImpl::LoadAclModel(Buffer om_data) {

@@ -72,44 +69,44 @@ Status AclGraphImpl::LoadAclModel(Buffer om_data) {
  auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
  if (acl_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed.";
-   return FAILED;
+   return kMCDeviceError;
  }

  // acl init model resource
  model_process_.set_model_id(acl_model_id);
  Status ret = model_process_.PreInitModelResource();
- if (ret != SUCCESS) {
+ if (ret != kSuccess) {
    (void)aclmdlUnload(acl_model_id);
    MS_LOG(ERROR) << "Pre init model resource failed.";
-   return FAILED;
+   return ret;
  }

  MS_LOG(INFO) << "Load acl model success.";
- return SUCCESS;
+ return kSuccess;
}

Status AclGraphImpl::InitEnv() {
  if (init_flag_) {
-   return SUCCESS;
+   return kSuccess;
  }

  acl_env_ = AclEnvGuard::GetAclEnv("");
  if (acl_env_ == nullptr) {
    MS_LOG(ERROR) << "Acl init failed.";
-   return FAILED;
+   return kMCDeviceError;
  }

  aclError ret = aclrtSetDevice(device_id_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed";
-   return FAILED;
+   return kMCDeviceError;
  }
  MS_LOG(INFO) << "Open device " << device_id_ << " success";

  ret = aclrtCreateContext(&context_, device_id_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl create context failed";
-   return FAILED;
+   return kMCDeviceError;
  }
  MS_LOG(INFO) << "Create context success";

@@ -117,7 +114,7 @@ Status AclGraphImpl::InitEnv() {
  ret = aclrtGetRunMode(&run_mode);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl get run mode failed";
-   return FAILED;
+   return kMCDeviceError;
  }
  bool is_device = (run_mode == ACL_DEVICE);
  model_process_.SetIsDevice(is_device);

@@ -125,24 +122,24 @@ Status AclGraphImpl::InitEnv() {
  MS_LOG(INFO) << "Init acl success, device id " << device_id_;
  init_flag_ = true;
- return SUCCESS;
+ return kSuccess;
}

Status AclGraphImpl::FinalizeEnv() {
  if (!init_flag_) {
-   return SUCCESS;
+   return kSuccess;
  }

  aclError rt_ret = aclrtSetCurrentContext(context_);
  if (rt_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Set the ascend device context failed";
-   return FAILED;
+   return kMCDeviceError;
  }

  Status ret = model_process_.UnLoad();
- if (ret != SUCCESS) {
+ if (ret != kSuccess) {
    MS_LOG(ERROR) << "Unload model inner failed.";
-   return FAILED;
+   return ret;
  }

  if (context_ != nullptr) {

@@ -161,16 +158,16 @@ Status AclGraphImpl::FinalizeEnv() {
  MS_LOG(INFO) << "End to reset device " << device_id_;

  init_flag_ = false;
- return SUCCESS;
+ return kSuccess;
}

Status AclGraphImpl::Load() {
  // check graph type
  if (graph_->ModelType() != ModelType::kOM) {
    Status ret = ConvertToOM();
-   if (ret != SUCCESS) {
+   if (ret != kSuccess) {
      MS_LOG(ERROR) << "Load Failed.";
-     return FAILED;
+     return ret;
    }
  }

@@ -180,15 +177,15 @@ Status AclGraphImpl::Load() {
  // init
  Status ret = InitEnv();
- if (ret != SUCCESS) {
+ if (ret != kSuccess) {
    MS_LOG(ERROR) << "InitEnv failed.";
-   return FAILED;
+   return ret;
  }

  // load model
  if (!load_flag_) {
    ret = LoadAclModel(om_data);
-   if (ret != SUCCESS) {
+   if (ret != kSuccess) {
      MS_LOG(ERROR) << "Load acl model failed.";
      return ret;
    }

@@ -198,24 +195,24 @@ Status AclGraphImpl::Load() {
  aclError rt_ret = aclrtSetCurrentContext(context_);
  if (rt_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Set the ascend device context failed";
-   return FAILED;
+   return kMCDeviceError;
  }

- return SUCCESS;
+ return kSuccess;
}

Status AclGraphImpl::ConvertToOM() {
  MS_LOG(INFO) << "Start convert to om model.";
  if (graph_ == nullptr) {
    MS_LOG(ERROR) << "Invalid graph_ is null.";
-   return FAILED;
+   return kMCFailed;
  }

  auto &graph_data = GraphImpl::MutableGraphData();
  MS_EXCEPTION_IF_NULL(graph_data);
  if (graph_->ModelType() == ModelType::kOM) {
    MS_LOG(INFO) << "This model has been built, skip.";
-   return SUCCESS;
+   return kSuccess;
  } else if (graph_->ModelType() == ModelType::kMindIR) {
    auto func_graph = graph_data->GetFuncGraph();
    MS_EXCEPTION_IF_NULL(func_graph);

@@ -223,13 +220,13 @@ Status AclGraphImpl::ConvertToOM() {
    Buffer om_data = model_converter.LoadMindIR(func_graph);
    if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
      MS_LOG(ERROR) << "Convert MindIR to OM failed.";
-     return FAILED;
+     return kMCFailed;
    }
    graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM);
    MS_LOG(INFO) << "Convert MindIR to OM success.";
-   return SUCCESS;
+   return kSuccess;
  }
  MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
- return FAILED;
+ return kMCFailed;
}
- } // namespace mindspore::api
+ } // namespace mindspore
@@ -27,18 +27,16 @@
#include "cxx_api/graph/graph_impl.h"
#include "cxx_api/factory.h"

- namespace mindspore::api {
+ namespace mindspore {
class AclGraphImpl : public GraphCell::GraphImpl {
 public:
  AclGraphImpl();
  ~AclGraphImpl() override;

- Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
+ Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
  Status Load() override;
- Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                      std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
- Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
+ std::vector<MSTensor> GetInputs() override;
+ std::vector<MSTensor> GetOutputs() override;

 private:
  Status ConvertToOM();

@@ -56,5 +54,5 @@ class AclGraphImpl : public GraphCell::GraphImpl {
  ModelProcess model_process_;
};
- } // namespace mindspore::api
+ } // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
@ -20,17 +20,19 @@
|
|||
#include <map>
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
namespace mindspore {
|
||||
static DataType TransToApiType(aclDataType data_type) {
|
||||
static const std::map<aclDataType, api::DataType> data_type_map = {
|
||||
{ACL_FLOAT16, api::kMsFloat16}, {ACL_FLOAT, api::kMsFloat32}, {ACL_DOUBLE, api::kMsFloat64},
|
||||
{ACL_INT8, api::kMsInt8}, {ACL_INT16, api::kMsInt16}, {ACL_INT32, api::kMsInt32},
|
||||
{ACL_INT64, api::kMsInt64}, {ACL_UINT8, api::kMsUint8}, {ACL_UINT16, api::kMsUint16},
|
||||
{ACL_UINT32, api::kMsUint32}, {ACL_UINT64, api::kMsUint64}, {ACL_BOOL, api::kMsBool},
|
||||
static const std::map<aclDataType, enum DataType> data_type_map = {
|
||||
{ACL_FLOAT16, DataType::kNumberTypeFloat16}, {ACL_FLOAT, DataType::kNumberTypeFloat32},
|
||||
{ACL_DOUBLE, DataType::kNumberTypeFloat64}, {ACL_INT8, DataType::kNumberTypeInt8},
|
||||
{ACL_INT16, DataType::kNumberTypeInt16}, {ACL_INT32, DataType::kNumberTypeInt32},
|
||||
{ACL_INT64, DataType::kNumberTypeInt64}, {ACL_UINT8, DataType::kNumberTypeUInt8},
|
||||
{ACL_UINT16, DataType::kNumberTypeUInt16}, {ACL_UINT32, DataType::kNumberTypeUInt32},
|
||||
{ACL_UINT64, DataType::kNumberTypeUInt64}, {ACL_BOOL, DataType::kNumberTypeBool},
|
||||
};
|
||||
auto it = data_type_map.find(data_type);
|
||||
if (it == data_type_map.end()) {
|
||||
return api::kInvalidDataType;
|
||||
return DataType::kTypeUnknown;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
|
@ -51,7 +53,7 @@ inline static void PushbackIfNotNull(U *vec, T &&item) {
|
|||
}
|
||||
|
||||
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
|
||||
std::vector<std::vector<int64_t>> *shapes, std::vector<DataType> *data_types,
|
||||
std::vector<std::vector<int64_t>> *shapes, std::vector<enum DataType> *data_types,
|
||||
std::vector<size_t> *mem_sizes) {
|
||||
ClearIfNotNull(names);
|
||||
ClearIfNotNull(shapes);
|
||||
|
@ -66,41 +68,69 @@ static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_lis
|
|||
}
|
||||
}
|
||||
|
||||
static std::string ShapeToString(const std::vector<int64_t> &shape) {
|
||||
std::string result = "[";
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
result += std::to_string(shape[i]);
|
||||
if (i + 1 < shape.size()) {
|
||||
result += ", ";
|
||||
}
|
||||
}
|
||||
result += "]";
|
||||
return result;
|
||||
}
|
||||
|
||||
Status ModelProcess::ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list,
|
||||
std::vector<MSTensor> *tensor_list) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
std::vector<std::string> names;
|
||||
std::vector<std::vector<int64_t>> shapes;
|
||||
std::vector<enum DataType> data_types;
|
||||
std::vector<size_t> mem_sizes;
|
||||
|
||||
ConstructTensorDesc(acl_tensor_list, &names, &shapes, &data_types, &mem_sizes);
|
||||
tensor_list->clear();
|
||||
if (names.size() != acl_tensor_list.size() || shapes.size() != acl_tensor_list.size() ||
|
||||
data_types.size() != acl_tensor_list.size() || mem_sizes.size() != acl_tensor_list.size()) {
|
||||
MS_LOG(ERROR) << "Inner error, size do not match: names size " << names.size() << " shapes size " << shapes.size()
|
||||
<< " data types size " << data_types.size() << " mem sizes size " << mem_sizes.size()
|
||||
<< " acl_tensor_list size " << acl_tensor_list.size();
|
||||
return kMCFailed;
|
||||
}
|
||||
|
||||
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
|
||||
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
|
||||
tensor_list->emplace_back(names[i], data_types[i], shapes[i], nullptr, mem_sizes[i]);
|
||||
auto ret = aclrtMemcpy((*tensor_list)[i].MutableData(), (*tensor_list)[i].DataSize(),
|
||||
acl_tensor_list[i].device_data, acl_tensor_list[i].buffer_size, kind);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Memcpy input " << i << " from " << (is_run_on_device_ ? "host" : "device")
|
||||
<< " to host failed, memory size " << acl_tensor_list[i].buffer_size;
|
||||
return kMCFailed;
|
||||
}
|
||||
}
|
||||
|
||||
return kSuccess;
|
||||
}
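A caller-side sketch of the flow above, for orientation only (not part of the commit). It assumes the MSTensor accessors used in this file (Name(), DataSize()); `model_process` stands for a hypothetical, already-loaded ModelProcess instance:

    // Tensors returned by GetOutputs() were filled by ConstructTensors(),
    // i.e. copied out of the device buffers via aclrtMemcpy above.
    std::vector<mindspore::MSTensor> outputs = model_process.GetOutputs();
    for (const auto &tensor : outputs) {
      std::cout << tensor.Name() << ": " << tensor.DataSize() << " bytes" << std::endl;
    }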
Status ModelProcess::PreInitModelResource() {
model_desc_ = aclmdlCreateDesc();
aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Read model desc failed";
return FAILED;
return kMCDeviceError;
}
Status ret = InitInputsBuffer();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Create input buffer failed";
return FAILED;
return ret;
}
ret = InitOutputsBuffer();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Create output buffer failed";
return FAILED;
return ret;
}
return SUCCESS;
}

Status ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t *model_id) {
MS_EXCEPTION_IF_NULL(model_id);
aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), model_id);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name;
return FAILED;
}
MS_LOG(INFO) << "Load model success " << file_name;
model_id_ = *model_id;
if (PreInitModelResource() != SUCCESS) {
aclmdlUnload(model_id_);
MS_LOG(ERROR) << "Pre init model resource failed, file name is " << file_name;
return FAILED;
}
return SUCCESS;
return kSuccess;
}

Status ModelProcess::InitInputsBuffer() {
@@ -113,8 +143,8 @@ Status ModelProcess::InitInputsBuffer() {
if (!is_run_on_device_) { // need to copy input/output to/from device
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device input buffer faild , input size " << buffer_size;
return FAILED;
MS_LOG(ERROR) << "Malloc device input buffer failed , input size " << buffer_size;
return kMCDeviceError;
}
}
@@ -125,7 +155,7 @@ Status ModelProcess::InitInputsBuffer() {
if (!is_run_on_device_) {
aclrtFree(data_mem_buffer);
}
return FAILED;
return kMCDeviceError;
}
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
@@ -137,7 +167,7 @@ Status ModelProcess::InitInputsBuffer() {
input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, input_name});
}
MS_LOG(INFO) << "Create model inputs success";
return SUCCESS;
return kSuccess;
}

Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
@@ -154,14 +184,14 @@ Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size
if (!is_run_on_device_) {
ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size;
return FAILED;
MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
return kMCDeviceError;
}
} else {
ret = aclrtMallocHost(data_mem_buffer, buffer_size);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size;
return FAILED;
MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
return kMCDeviceError;
}
}
@@ -169,16 +199,16 @@ Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size
if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed";
free_data_buffer(*data_mem_buffer);
return FAILED;
return kMCDeviceError;
}
ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed";
free_data_buffer(*data_mem_buffer);
aclDestroyDataBuffer(data_buffer);
return FAILED;
return kMCDeviceError;
}
return SUCCESS;
return kSuccess;
}

Status ModelProcess::InitOutputsBuffer() {
@@ -186,7 +216,7 @@ Status ModelProcess::InitOutputsBuffer() {
outputs_ = aclmdlCreateDataset();
if (outputs_ == nullptr) {
MS_LOG(ERROR) << "Create input dataset failed";
return FAILED;
return kMCDeviceError;
}
size_t output_size = aclmdlGetNumOutputs(model_desc_);
MS_LOG(INFO) << "output_size = " << output_size;
@@ -194,9 +224,9 @@ Status ModelProcess::InitOutputsBuffer() {
auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);

void *data_mem_buffer = nullptr;
if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != SUCCESS) {
if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != kSuccess) {
MS_LOG(ERROR) << "add output data buffer failed, buffer size " << buffer_size;
return FAILED;
return kMCDeviceError;
}
aclmdlIODims dims;
ret = aclmdlGetOutputDims(model_desc_, i, &dims);
@@ -207,7 +237,7 @@ Status ModelProcess::InitOutputsBuffer() {
} else {
aclrtFreeHost(data_mem_buffer);
}
return FAILED;
return kMCDeviceError;
}
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
@@ -219,7 +249,7 @@ Status ModelProcess::InitOutputsBuffer() {
output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, output_name});
}
MS_LOG(INFO) << "Create model output success";
return SUCCESS;
return kSuccess;
}

void ModelProcess::DestroyInputsDataset() {
@@ -273,50 +303,60 @@ Status ModelProcess::UnLoad() {
auto ret = aclmdlUnload(model_id_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed";
return FAILED;
return kMCDeviceError;
}
if (model_desc_ != nullptr) {
ret = aclmdlDestroyDesc(model_desc_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed";
return FAILED;
return kMCDeviceError;
}
model_desc_ = nullptr;
}
DestroyInputsBuffer();
DestroyOutputsBuffer();
MS_LOG(INFO) << "End unload model " << model_id_;
return SUCCESS;
return kSuccess;
}

Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) {
Status ModelProcess::CheckAndInitInput(const std::vector<MSTensor> &inputs) {
aclError ret;
inputs_ = aclmdlCreateDataset();
// check inputs
if (inputs.size() != input_infos_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given count "
MS_LOG(ERROR) << "Inputs count not match, required count " << input_infos_.size() << ", given count "
<< inputs.size();
return INVALID_INPUTS;
return kMCInvalidInput;
}
for (size_t i = 0; i < input_infos_.size(); ++i) {
if (inputs[i].Shape() != input_infos_[i].dims) {
MS_LOG(INFO) << "Note: input " << i << " shape not match, required " << ShapeToString(input_infos_[i].dims)
<< ", given " << ShapeToString(inputs[i].Shape());
}

if (inputs[i].DataType() != TransToApiType(input_infos_[i].data_type)) {
MS_LOG(INFO) << "Note: input " << i << " data type not match, required "
<< TransToApiType(input_infos_[i].data_type) << ", given " << inputs[i].DataType();
}

if (inputs[i].DataSize() != input_infos_[i].buffer_size) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size
MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << input_infos_[i].buffer_size
<< ", given count " << inputs[i].DataSize();
return INVALID_INPUTS;
return kMCInvalidInput;
}
}
// copy inputs
for (size_t i = 0; i < input_infos_.size(); ++i) {
const auto &info = input_infos_[i];
const auto &input = inputs[i];
const void *data = input.Data();
auto input = inputs[i];
const void *data = input.MutableData();

void *input_buffer = nullptr;
if (!is_run_on_device_) {
ret = aclrtMemcpy(info.device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize();
return FAILED;
return kMCDeviceError;
}
input_buffer = info.device_data;
} else {
@@ -325,23 +365,23 @@ Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) {
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed";
return FAILED;
return kMCDeviceError;
}
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed";
aclDestroyDataBuffer(data_buffer);
return FAILED;
return kMCDeviceError;
}
}
return SUCCESS;
return kSuccess;
}

Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status ModelProcess::PredictFromHost(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
aclError acl_ret;
Status ret = CheckAndInitInput(inputs);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "check or init input failed";
DestroyInputsDataset();
return ret; // forward status error
@@ -361,50 +401,48 @@ Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vec
DestroyInputsDataset();
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCDeviceError;
}
ret = BuildOutputs(outputs);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Build outputs faield";
return FAILED;
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build outputs failed";
return ret;
}
MS_LOG(INFO) << "excute model success";
return SUCCESS;
MS_LOG(INFO) << "Execute model success";
return kSuccess;
}

Status ModelProcess::BuildOutputs(std::vector<Buffer> *outputs) {
Status ModelProcess::BuildOutputs(std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
aclError ret;
// copy outputs
outputs->clear();
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
for (size_t i = 0; i < output_infos_.size(); ++i) {
const auto &info = output_infos_[i];
outputs->emplace_back(Buffer());
auto output = outputs->rbegin();
if (!output->ResizeData(info.buffer_size)) {
MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size;
return FAILED;
}
ret = aclrtMemcpy(output->MutableData(), output->DataSize(), info.device_data, info.buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device")
<< " to host failed, memory size " << info.buffer_size;
return FAILED;
}
auto inner_outputs = GetOutputs();
if (inner_outputs.size() != output_infos_.size()) {
MS_LOG(ERROR) << "Invalid inner outputs size " << inner_outputs.size() << " do not match device output infos size "
<< output_infos_.size();
return kMCFailed;
}
return SUCCESS;
(*outputs) = inner_outputs;
return kSuccess;
}

Status ModelProcess::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
ConstructTensorDesc(input_infos_, names, shapes, data_types, mem_sizes);
return SUCCESS;
std::vector<MSTensor> ModelProcess::GetInputs() {
Status ret = ConstructTensors(input_infos_, &input_tensors_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "ConstructTensors failed.";
input_tensors_.clear();
}

return input_tensors_;
}

Status ModelProcess::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
ConstructTensorDesc(output_infos_, names, shapes, data_types, mem_sizes);
return SUCCESS;
std::vector<MSTensor> ModelProcess::GetOutputs() {
Status ret = ConstructTensors(output_infos_, &output_tensors_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "ConstructTensors failed.";
output_tensors_.clear();
}

return output_tensors_;
}
} // namespace mindspore::api
} // namespace mindspore
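For orientation, the call order implied by the refactored ModelProcess API above (a sketch, not part of the commit; ModelProcess is an internal class and the model path is a placeholder):

    mindspore::ModelProcess model_process;
    uint32_t model_id = 0;
    // Load an offline .om model and set up device buffers.
    if (model_process.LoadModelFromFile("model.om", &model_id) != mindspore::kSuccess) {
      return;
    }
    std::vector<mindspore::MSTensor> inputs = model_process.GetInputs();
    // ... fill inputs[i] through MutableData()/DataSize() ...
    std::vector<mindspore::MSTensor> outputs;
    if (model_process.PredictFromHost(inputs, &outputs) == mindspore::kSuccess) {
      // outputs now hold host copies of the device results.
    }
    model_process.UnLoad();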
@@ -25,7 +25,7 @@
#include "include/api/status.h"
#include "include/api/types.h"

namespace mindspore::api {
namespace mindspore {
struct AclTensorInfo {
void *device_data;
size_t buffer_size;
@@ -45,14 +45,12 @@ class ModelProcess {
input_infos_(),
output_infos_() {}
~ModelProcess() {}
Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id);

Status UnLoad();
Status PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status PredictFromHost(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status PreInitModelResource();
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();

// override this method to avoid request/reply data copy
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
@@ -62,8 +60,9 @@ class ModelProcess {
private:
Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
Status CheckAndInitInput(const std::vector<Buffer> &inputs);
Status BuildOutputs(std::vector<Buffer> *outputs);
Status CheckAndInitInput(const std::vector<MSTensor> &inputs);
Status ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<MSTensor> *tensor_list);
Status BuildOutputs(std::vector<MSTensor> *outputs);
Status InitInputsBuffer();
Status InitOutputsBuffer();
@@ -80,7 +79,9 @@ class ModelProcess {
aclmdlDataset *outputs_;
std::vector<AclTensorInfo> input_infos_;
std::vector<AclTensorInfo> output_infos_;
std::vector<MSTensor> input_tensors_;
std::vector<MSTensor> output_tensors_;
};
} // namespace mindspore::api
} // namespace mindspore

#endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H

@@ -25,91 +25,51 @@
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"

namespace mindspore::api {
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);

AscendGraphImpl::AscendGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_type_("Ascend"),
device_id_(Context::Instance().GetDeviceID()),
device_id_(GlobalContext::GetGlobalDeviceID()),
context_(nullptr),
inputs_(),
outputs_(),
inputs_info_(),
outputs_info_(),
input_names_(),
output_names_(),
init_flag_(false),
load_flag_(false) {}

AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); }
AscendGraphImpl::~AscendGraphImpl() {}

Status AscendGraphImpl::InitEnv() {
if (init_flag_) {
return SUCCESS;
}
RegAllOp();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}

ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;
MS_LOG(INFO) << "Start to init env.";
env_guard_ = MsEnvGuard::GetEnv(device_id_);
if (env_guard_ == nullptr) {
MS_LOG(ERROR) << "Env init failed.";
return kMCDeviceError;
}

session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
<< " is available.";
return FAILED;
return kMCFailed;
}
session_impl_->Init(device_id_);

init_flag_ = true;
return SUCCESS;
}

Status AscendGraphImpl::FinalizeEnv() {
if (!init_flag_) {
return SUCCESS;
}

MS_LOG_INFO << "Start finalize env";
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();

auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}

{
PythonEnvGuard guard;
if (!context::CloseTsd(ms_context)) {
MS_LOG(ERROR) << "CloseTsd failed!";
return FAILED;
}
}

init_flag_ = false;
MS_LOG(INFO) << "End finalize env";
return SUCCESS;
MS_LOG(INFO) << "InitEnv success.";
return kSuccess;
}

Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return SUCCESS;
return kSuccess;
} catch (std::exception &e) {
MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
return FAILED;
return kMCFailed;
}
}
@@ -128,104 +88,104 @@ Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &i
MS_ASSERT(session_impl_ != nullptr);
std::string error_msg;
if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
return Status(INVALID_INPUTS, error_msg);
return Status(kMCInvalidInput, error_msg);
}
return SUCCESS;
return kSuccess;
}

Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
Status AscendGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vector<MSTensor> *reply) {
MS_EXCEPTION_IF_NULL(reply);
if (context_ == nullptr) {
MS_LOG(ERROR) << "rtCtx is nullptr";
return FAILED;
return kMCDeviceError;
}
rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Set Ascend rtCtx failed";
return FAILED;
return kMCDeviceError;
}

vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i];
auto input = inputs_[i];
auto item = request[i];
auto input = inputs_info_[i];
if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size();
return FAILED;
return kMCInvalidInput;
}
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Tensor copy failed";
return FAILED;
auto ret = memcpy_s(input->data_c(), input->Size(), item.MutableData(), item.DataSize());
if (ret != kSuccess) {
MS_LOG(ERROR) << "MSTensor copy failed";
return kMCFailed;
}
inputs.push_back(input);
}
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
last_inputs_ = inputs;
std::vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCFailed;
}
last_outputs_ = outputs;
reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
return SUCCESS;
*reply = GetOutputs();
return kSuccess;
}

Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> AscendGraphImpl::GetInputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}

GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < inputs_.size(); i++) {
auto &tensor = inputs_[i];
GraphUtils::PushbackIfNotNull(names, input_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(inputs_info_.size());
for (size_t i = 0; i < inputs_info_.size(); ++i) {
auto &tensor = inputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_inputs_.size()) {
data = last_inputs_[i]->data_c();
data_size = last_inputs_[i]->Size();
}
result[i] =
MSTensor(input_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}
return SUCCESS;
return result;
}

Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}

GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < outputs_.size(); i++) {
auto &tensor = outputs_[i];
GraphUtils::PushbackIfNotNull(names, output_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(outputs_info_.size());
for (size_t i = 0; i < outputs_info_.size(); ++i) {
auto &tensor = outputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_outputs_.size()) {
data = last_outputs_[i]->data_c();
data_size = last_outputs_[i]->Size();
}
result[i] =
MSTensor(output_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}

return SUCCESS;
return result;
}

Status AscendGraphImpl::Load() {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
return INVALID_INPUTS;
return kMCInvalidInput;
}

const auto &graph_data = GraphImpl::MutableGraphData();
@@ -234,34 +194,34 @@ Status AscendGraphImpl::Load() {
// init
Status ret = InitEnv();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";
return FAILED;
return ret;
}

// load model
if (!load_flag_) {
ret = CompileGraph(func_graph);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Compile graph model failed";
return FAILED;
return ret;
}
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed";
return FAILED;
return kMCInvalidInput;
}
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed";
return FAILED;
return kMCInvalidInput;
}

// save d context
rtError_t rt_ret = rtCtxGetCurrent(&context_);
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null";
return FAILED;
return kMCDeviceError;
}

MS_LOG(INFO) << "Load model success";
@@ -271,44 +231,112 @@ Status AscendGraphImpl::Load() {
rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed";
return FAILED;
return kMCDeviceError;
}

return SUCCESS;
return kSuccess;
}

Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
}
}

if (inputs.size() != inputs_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
return INVALID_INPUTS;
if (inputs.size() != inputs_info_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count "
<< inputs.size();
return kMCInvalidInput;
}

for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs[i].DataSize() != inputs_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
<< inputs[i].DataSize();
return INVALID_INPUTS;
for (size_t i = 0; i < inputs_info_.size(); ++i) {
if (inputs[i].DataSize() != inputs_info_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size()
<< ", given count " << inputs[i].DataSize();
return kMCInvalidInput;
}
}
if (ExecuteModel(inputs, outputs) != SUCCESS) {

Status ret = ExecuteModel(inputs, outputs);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return ret;
}
if (outputs_.size() != outputs->size()) {
if (outputs_info_.size() != outputs->size()) {
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
<< outputs_.size();
return FAILED;
<< outputs_info_.size();
return kMCFailed;
}

return SUCCESS;
return kSuccess;
}
} // namespace mindspore::api

AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
MS_LOG(INFO) << "Start to init env.";
device_id_ = device_id;
RegAllOp();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return;
}

ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
auto ret = rtSetDevice(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
}

MS_LOG(INFO) << "InitEnv success.";
errno_ = kSuccess;
}

AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
MS_LOG(INFO) << "Start finalize env";
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();

auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return;
}

auto ret = rtDeviceReset(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
}

errno_ = kSuccess;
MS_LOG(INFO) << "End finalize env";
}

std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) {
std::shared_ptr<MsEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_ms_env_mutex_);
acl_env = global_ms_env_.lock();
if (acl_env != nullptr) {
MS_LOG(INFO) << "Env has been initialized, skip.";
} else {
acl_env = std::make_shared<MsEnvGuard>(device_id);
if (acl_env->GetErrno() != kSuccess) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return nullptr;
}
global_ms_env_ = acl_env;
MS_LOG(INFO) << "Env init success";
}
return acl_env;
}

std::weak_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
} // namespace mindspore
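MsEnvGuard keeps the Ascend environment alive exactly as long as at least one AscendGraphImpl holds the shared_ptr; the weak_ptr lets the environment be torn down when the last graph goes away and re-created later. The same idiom in isolation (a sketch under that reading, independent of MindSpore):

    #include <memory>
    #include <mutex>

    class EnvGuard {
     public:
      EnvGuard() { /* acquire the device here */ }
      ~EnvGuard() { /* release the device here */ }

      static std::shared_ptr<EnvGuard> GetEnv() {
        std::lock_guard<std::mutex> lock(mutex_);
        std::shared_ptr<EnvGuard> env = global_env_.lock();  // reuse if still alive
        if (env == nullptr) {
          env = std::make_shared<EnvGuard>();
          global_env_ = env;  // weak_ptr: the registry itself keeps nothing alive
        }
        return env;
      }

     private:
      static std::weak_ptr<EnvGuard> global_env_;
      static std::mutex mutex_;
    };

    std::weak_ptr<EnvGuard> EnvGuard::global_env_;
    std::mutex EnvGuard::mutex_;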
@@ -28,40 +28,56 @@
#include "ir/anf.h"
#include "cxx_api/model/model_impl.h"
#include "runtime/context.h"
#include "cxx_api/graph/graph_utils.h"

namespace mindspore::api {
namespace mindspore {
class AscendGraphImpl : public GraphCell::GraphImpl {
public:
AscendGraphImpl();
~AscendGraphImpl() override;

Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

private:
class MsEnvGuard;

Status InitEnv();
Status FinalizeEnv();
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status ExecuteModel(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

std::shared_ptr<session::SessionBasic> session_impl_;
uint32_t graph_id_;
std::string device_type_;
uint32_t device_id_;
rtContext_t context_;
std::vector<tensor::TensorPtr> inputs_;
std::vector<tensor::TensorPtr> outputs_;
std::vector<tensor::TensorPtr> inputs_info_;
std::vector<tensor::TensorPtr> outputs_info_;
std::vector<tensor::TensorPtr> last_inputs_;
std::vector<tensor::TensorPtr> last_outputs_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
bool init_flag_;
bool load_flag_;

std::shared_ptr<MsEnvGuard> env_guard_;
};
} // namespace mindspore::api

class AscendGraphImpl::MsEnvGuard {
public:
explicit MsEnvGuard(uint32_t device_id);
~MsEnvGuard();
Status GetErrno() const { return errno_; }
static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id);

private:
static std::weak_ptr<MsEnvGuard> global_ms_env_;
static std::mutex global_ms_env_mutex_;

Status errno_;
uint32_t device_id_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H

@@ -23,15 +23,15 @@
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"

namespace mindspore::api {
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);

GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_id_(Context::Instance().GetDeviceID()),
inputs_(),
outputs_(),
device_id_(GlobalContext::GetGlobalDeviceID()),
inputs_info_(),
outputs_info_(),
input_names_(),
output_names_(),
init_flag_(false),
@@ -40,13 +40,13 @@ GPUGraphImpl::GPUGraphImpl()
Status GPUGraphImpl::InitEnv() {
if (init_flag_) {
MS_LOG(WARNING) << "Initialized again, return success.";
return SUCCESS;
return kSuccess;
}

auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
return kMCFailed;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
@@ -57,18 +57,18 @@ Status GPUGraphImpl::InitEnv() {
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kGpuInferenceDevice
<< " is available.";
return FAILED;
return kMCFailed;
}

session_impl_->Init(device_id_);
init_flag_ = true;
return SUCCESS;
return kSuccess;
}

Status GPUGraphImpl::FinalizeEnv() {
if (!init_flag_) {
MS_LOG(WARNING) << "Never initialize before, return success";
return SUCCESS;
return kSuccess;
}

MS_LOG_INFO << "Start finalize env";
@@ -77,14 +77,14 @@ Status GPUGraphImpl::FinalizeEnv() {
init_flag_ = false;
MS_LOG(INFO) << "End finalize env";
return SUCCESS;
return kSuccess;
}

Status GPUGraphImpl::Load() {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
return INVALID_INPUTS;
return kMCInvalidInput;
}

const auto &graph_data = GraphImpl::MutableGraphData();
@@ -93,38 +93,38 @@ Status GPUGraphImpl::Load() {
// init
Status ret = InitEnv();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";
return FAILED;
return kMCDeviceError;
}

ret = CompileGraph(func_graph);
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Compile graph model failed";
return FAILED;
return kMCFailed;
}
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed";
return FAILED;
return kMCInvalidInput;
}
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed";
return FAILED;
return kMCInvalidInput;
}
load_flag_ = true;
return SUCCESS;
return kSuccess;
}

Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return SUCCESS;
return kSuccess;
} catch (std::exception &e) {
MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
return FAILED;
return kMCFailed;
}
}
@@ -139,118 +139,118 @@ std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::
}
}

Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
Status GPUGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vector<MSTensor> *reply) {
MS_EXCEPTION_IF_NULL(reply);

vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i];
auto input = inputs_[i];
auto input = inputs_info_[i];
if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size();
return FAILED;
return kMCInvalidInput;
}
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
if (ret != SUCCESS) {
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data().get(), item.DataSize());
if (ret != kSuccess) {
MS_LOG(ERROR) << "Tensor copy failed";
return FAILED;
return kMCFailed;
}
inputs.push_back(input);
}
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
last_inputs_ = inputs;
std::vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCFailed;
}
last_outputs_ = outputs;
reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
return SUCCESS;
*reply = GetOutputs();
return kSuccess;
}

Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
}
}

if (inputs.size() != inputs_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
return INVALID_INPUTS;
if (inputs.size() != inputs_info_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count "
<< inputs.size();
return kMCInvalidInput;
}

for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs[i].DataSize() != inputs_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
<< inputs[i].DataSize();
return INVALID_INPUTS;
for (size_t i = 0; i < inputs_info_.size(); ++i) {
if (inputs[i].DataSize() != inputs_info_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size()
<< ", given count " << inputs[i].DataSize();
return kMCInvalidInput;
}
}
if (ExecuteModel(inputs, outputs) != SUCCESS) {
if (ExecuteModel(inputs, outputs) != kSuccess) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
return kMCFailed;
}
if (outputs_.size() != outputs->size()) {
if (outputs_info_.size() != outputs->size()) {
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
<< outputs_.size();
return FAILED;
<< outputs_info_.size();
return kMCFailed;
}

return SUCCESS;
return kSuccess;
}

Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> GPUGraphImpl::GetInputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}

GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < inputs_.size(); i++) {
auto &tensor = inputs_[i];
GraphUtils::PushbackIfNotNull(names, input_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(inputs_info_.size());
for (size_t i = 0; i < inputs_info_.size(); ++i) {
auto &tensor = inputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_inputs_.size()) {
data = last_inputs_[i]->data_c();
data_size = last_inputs_[i]->Size();
}
result[i] =
MSTensor(input_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}
return SUCCESS;
return result;
}

Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
if (!load_flag_) {
Status ret = Load();
if (ret != SUCCESS) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
return {};
}
}

GraphUtils::ClearIfNotNull(names);
GraphUtils::ClearIfNotNull(shapes);
GraphUtils::ClearIfNotNull(data_types);
GraphUtils::ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < outputs_.size(); i++) {
auto &tensor = outputs_[i];
GraphUtils::PushbackIfNotNull(names, output_names_[i]);
GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
std::vector<MSTensor> result(outputs_info_.size());
for (size_t i = 0; i < outputs_info_.size(); ++i) {
auto &tensor = outputs_info_[i];
void *data = nullptr;
size_t data_size = tensor->Size();
if (i < last_outputs_.size()) {
data = last_outputs_[i]->data_c();
data_size = last_outputs_[i]->Size();
}
result[i] =
MSTensor(output_names_[i], static_cast<enum DataType>(tensor->data_type()), tensor->shape(), data, data_size);
}

return SUCCESS;
return result;
}

} // namespace mindspore::api
} // namespace mindspore
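Both backends register their GraphImpl with a factory through API_FACTORY_REG. The macro's expansion is not shown in this diff; the usual self-registration idiom behind such a macro looks roughly like this (a sketch, not MindSpore's actual factory):

    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    template <class Base>
    class Factory {
     public:
      static Factory &Instance() {
        static Factory instance;
        return instance;
      }
      void Register(const std::string &name, std::function<std::shared_ptr<Base>()> creator) {
        creators_[name] = std::move(creator);
      }
      std::shared_ptr<Base> Create(const std::string &name) {
        auto it = creators_.find(name);
        return it == creators_.end() ? nullptr : it->second();
      }

     private:
      std::map<std::string, std::function<std::shared_ptr<Base>()>> creators_;
    };

    // A registrar object at namespace scope runs its constructor during static
    // initialization, which is when the implementation gets registered.
    template <class Base, class Impl>
    struct Registrar {
      explicit Registrar(const std::string &name) {
        Factory<Base>::Instance().Register(name, [] { return std::make_shared<Impl>(); });
      }
    };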
@@ -25,20 +25,17 @@
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/graph/graph_utils.h"

namespace mindspore::api {
namespace mindspore {
class GPUGraphImpl : public GraphCell::GraphImpl {
public:
GPUGraphImpl();
~GPUGraphImpl() override = default;

Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

private:
Status InitEnv();
@@ -46,14 +43,16 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status ExecuteModel(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

std::shared_ptr<session::SessionBasic> session_impl_;
uint32_t graph_id_;
std::string device_type_;
uint32_t device_id_;
std::vector<tensor::TensorPtr> inputs_;
std::vector<tensor::TensorPtr> outputs_;
std::vector<tensor::TensorPtr> inputs_info_;
std::vector<tensor::TensorPtr> outputs_info_;
std::vector<tensor::TensorPtr> last_inputs_;
std::vector<tensor::TensorPtr> last_outputs_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
bool init_flag_;
@@ -63,5 +62,5 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
uint32_t batch_size_;
uint32_t workspace_size_;
};
} // namespace mindspore::api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H

@@ -17,15 +17,19 @@
#include "cxx_api/graph/graph_data.h"
#include "utils/log_adapter.h"

namespace mindspore::api {
namespace mindspore {
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}

Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}

Graph::~Graph() {}

Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}

bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }

ModelType Graph::ModelType() const {
MS_EXCEPTION_IF_NULL(graph_data_);
return graph_data_->ModelType();
}
} // namespace mindspore::api
} // namespace mindspore
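The Graph wrapper above is what AclModel::Build constructs around the converted OM data later in this commit. A usage sketch (it assumes only the constructors and operators shown here; `om_data` is a hypothetical Buffer holding an OM model):

    auto graph = std::make_shared<mindspore::Graph>(
        std::make_shared<mindspore::Graph::GraphData>(om_data, mindspore::ModelType::kOM));
    if (*graph == nullptr) {
      // graph_data_ was null
    }
    mindspore::ModelType type = graph->ModelType();  // kOM for this graph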
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "framework/common/helper/model_helper.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::api {
|
||||
namespace mindspore {
|
||||
Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
|
||||
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
|
||||
if (model_type != ModelType::kMindIR) {
|
||||
|
@ -72,4 +72,4 @@ Buffer Graph::GraphData::GetOMData() const {
|
|||
|
||||
return om_data_;
|
||||
}
|
||||
} // namespace mindspore::api
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include "include/api/types.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
namespace mindspore {
|
||||
class Graph::GraphData {
|
||||
public:
|
||||
GraphData();
|
||||
|
@ -46,5 +46,5 @@ class Graph::GraphData {
|
|||
Buffer om_data_;
|
||||
enum ModelType model_type_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
namespace mindspore {
|
||||
class GraphCell::GraphImpl {
|
||||
public:
|
||||
GraphImpl() = default;
|
||||
|
@ -35,17 +35,14 @@ class GraphCell::GraphImpl {
|
|||
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
|
||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||
|
||||
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
|
||||
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
|
||||
virtual Status Load() = 0;
|
||||
|
||||
virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
|
||||
virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
|
||||
virtual std::vector<MSTensor> GetInputs() = 0;
|
||||
virtual std::vector<MSTensor> GetOutputs() = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Graph> graph_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
||||
|
|
|
@ -1,63 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
#include <map>
#include <vector>
#include "include/api/types.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"

namespace mindspore::api {
class GraphUtils {
 public:
  static DataType TransTypeId2InferDataType(TypeId type_id) {
    const std::map<TypeId, api::DataType> id2type_map{
      {TypeId::kNumberTypeBegin, api::kMsUnknown},   {TypeId::kNumberTypeBool, api::kMsBool},
      {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
      {TypeId::kNumberTypeUInt8, api::kMsUint8},     {TypeId::kNumberTypeInt16, api::kMsInt16},
      {TypeId::kNumberTypeUInt16, api::kMsUint16},   {TypeId::kNumberTypeInt32, api::kMsInt32},
      {TypeId::kNumberTypeUInt32, api::kMsUint32},   {TypeId::kNumberTypeInt64, api::kMsInt64},
      {TypeId::kNumberTypeUInt64, api::kMsUint64},   {TypeId::kNumberTypeFloat16, api::kMsFloat16},
      {TypeId::kNumberTypeFloat32, api::kMsFloat32},
    };

    auto it = id2type_map.find(type_id);
    if (it != id2type_map.end()) {
      return it->second;
    }

    MS_LOG(WARNING) << "Unsupported data id " << type_id;
    return api::kMsUnknown;
  }

  template <class T>
  inline static void ClearIfNotNull(T *vec) {
    if (vec != nullptr) {
      vec->clear();
    }
  }

  template <class T, class U>
  inline static void PushbackIfNotNull(U *vec, T &&item) {
    if (vec != nullptr) {
      vec->emplace_back(item);
    }
  }
};
}  // namespace mindspore::api

#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H

@ -16,47 +16,53 @@

#include "cxx_api/model/acl/acl_model.h"
#include <memory>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/python_utils.h"

namespace mindspore::api {
namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);

Status AclModel::Build(const std::map<std::string, std::string> &options_map) {
Status AclModel::Build() {
  MS_LOG(INFO) << "Start build model.";
  MS_EXCEPTION_IF_NULL(graph_);
  std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(options_map);
  std::string options_str = GenerateOptionsStr(options_map);
  MS_EXCEPTION_IF_NULL(options);
  if (graph_cell_ != nullptr && options_str == options_str_) {

  if (graph_cell_ != nullptr) {
    MS_LOG(INFO) << "This model has been built, skip.";
    return SUCCESS;
    return kSuccess;
  }

  if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) {
    MS_LOG(INFO) << "Note: Load om model and all build options will be ignored.";
    graph_cell_ = std::make_shared<GraphCell>(graph_);
    MS_EXCEPTION_IF_NULL(graph_cell_);
    if (!options_map.empty()) {
      MS_LOG(WARNING) << "All build options will be ignored.";
    return kSuccess;
  }

  std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_);
  MS_EXCEPTION_IF_NULL(options);
  std::string options_key = options->GenAclOptionsKey();
  std::shared_ptr<Graph> graph;
  if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
    MS_LOG(INFO) << "This options has been built, read cache.";
    graph = iter->second;
  } else {
    auto func_graph = ModelImpl::GetFuncGraph();
    MS_EXCEPTION_IF_NULL(func_graph);
    model_converter_.set_options(options.get());
    auto om_data = model_converter_.LoadMindIR(func_graph);
    if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
      MS_LOG(ERROR) << "Load MindIR failed.";
      return kMCFailed;
    }
    return SUCCESS;
    graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
    dynamic_size_graph_map_[options_key] = graph;
  }

  auto func_graph = ModelImpl::GetFuncGraph();
  MS_EXCEPTION_IF_NULL(func_graph);
  model_converter_.set_options(options.get());
  auto om_data = model_converter_.LoadMindIR(func_graph);
  if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
    MS_LOG(ERROR) << "Load MindIR failed.";
    return FAILED;
  }

  auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_cell = std::make_shared<GraphCell>(graph);
  MS_EXCEPTION_IF_NULL(graph_cell);
  auto ret = ModelImpl::Load(graph_cell);
  if (ret != SUCCESS) {
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Load failed.";
    return ret;
  }

@ -64,64 +70,97 @@ Status AclModel::Build(const std::map<std::string, std::string> &options_map) {
  // save result
  graph_cell_ = graph_cell;
  options_ = std::move(options);
  options_str_ = options_str;
  MS_LOG(INFO) << "Build model success.";
  return SUCCESS;
  return kSuccess;
}

Status AclModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  MS_LOG(INFO) << "Start to resize model.";
  MS_EXCEPTION_IF_NULL(graph_);
  if (graph_->ModelType() == ModelType::kOM) {
    MS_LOG(ERROR) << "OM model is not supported to resize model.";
    return kMCFailed;
  }

  auto origin_inputs = GetInputs();
  if (inputs.size() != origin_inputs.size()) {
    MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size();
    return kMCInvalidInput;
  }

  if (inputs.size() != dims.size()) {
    MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size();
    return kMCInvalidInput;
  }

  if (model_context_ == nullptr) {
    model_context_ = std::make_shared<ModelContext>();
  }

  std::string input_shape_option;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i].Name() != origin_inputs[i].Name()) {
      MS_LOG(ERROR) << "Invalid inputs " << i << " name " << inputs[i].Name() << " not match model input name "
                    << origin_inputs[i].Name();
      return kMCInvalidInput;
    }
    input_shape_option += inputs[i].Name() + ":";
    for (size_t j = 0; j < dims[i].size(); ++j) {
      input_shape_option += std::to_string(dims[i][j]);
      if (j + 1 < dims[i].size()) {
        input_shape_option += ",";
      }
    }
    if (i + 1 < inputs.size()) {
      input_shape_option += ";";
    }
  }
  MS_LOG(INFO) << "Set input size option is " << input_shape_option;
  ModelContext::SetInputShape(model_context_, input_shape_option);
  auto graph_cell_bak = std::move(graph_cell_);
  auto ret = Build();
  if (ret != kSuccess) {
    MS_LOG(INFO) << "Resize build failed.";
    graph_cell_ = std::move(graph_cell_bak);
    return ret;
  }
  MS_LOG(INFO) << "Resize success.";
  return kSuccess;
}

Status AclModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
}

Status AclModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status AclModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  if (graph_ == nullptr) {
    MS_LOG(ERROR) << "Invalid data, graph_ is null.";
    return FAILED;
    return kMCFailed;
  }

  if (graph_cell_ == nullptr) {
    MS_LOG(WARNING) << "Model has not been built, it will be built with default options";
    Status ret = Build({});
    if (ret != SUCCESS) {
    Status ret = Build();
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed.";
      return FAILED;
      return ret;
    }
  }

  MS_EXCEPTION_IF_NULL(graph_cell_);
  Status ret = graph_cell_->Run(inputs, outputs);
  if (ret != SUCCESS) {
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Run graph failed.";
    return FAILED;
    return ret;
  }

  return SUCCESS;
  return kSuccess;
}

Status AclModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                               std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> AclModel::GetInputs() {
  MS_EXCEPTION_IF_NULL(graph_cell_);
  return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes);
  return graph_cell_->GetInputs();
}

Status AclModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                                std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> AclModel::GetOutputs() {
  MS_EXCEPTION_IF_NULL(graph_cell_);
  return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
  return graph_cell_->GetOutputs();
}

std::string AclModel::GenerateOptionsStr(const std::map<std::string, std::string> &options) {
  std::string ret;
  for (auto &[key, value] : options) {
    ret += key + "^" + value + "^^";
  }
  return ret;
}
}  // namespace mindspore::api
}  // namespace mindspore

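Note: the Resize() path above serializes the requested shapes into an input-shape option and rebuilds the graph through the options-key cache. A minimal caller-side sketch, assuming a model already constructed through the public API; the batch-resize policy and the assumption that dim 0 is the batch dimension are illustrative, not part of this commit:

// Sketch only: drives the new Resize() through the public Model API.
#include <vector>
#include "include/api/model.h"
#include "include/api/types.h"

mindspore::Status ResizeBatch(mindspore::Model *model, int64_t batch) {
  std::vector<mindspore::MSTensor> inputs = model->GetInputs();
  std::vector<std::vector<int64_t>> dims;
  for (auto &input : inputs) {
    std::vector<int64_t> shape = input.Shape();
    shape[0] = batch;  // assumes dim 0 is the batch dimension (illustrative)
    dims.push_back(shape);
  }
  return model->Resize(inputs, dims);  // kSuccess when the rebuild succeeds
}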
@ -31,30 +31,25 @@
#include "ir/tensor.h"
#include "ir/anf.h"

namespace mindspore::api {
namespace mindspore {
class AclModel : public ModelImpl {
 public:
  AclModel() : model_converter_(), options_(nullptr), options_str_() {}
  AclModel() : model_converter_(), options_(nullptr) {}
  ~AclModel() = default;

  Status Build(const std::map<std::string, std::string> &options_map) override;
  Status Build() override;
  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;

  Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
  Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
  Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;

  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
  std::vector<MSTensor> GetInputs() override;
  std::vector<MSTensor> GetOutputs() override;

 private:
  static std::string GenerateOptionsStr(const std::map<std::string, std::string> &options);

  std::shared_ptr<GraphCell> graph_cell_;
  ModelConverter model_converter_;
  std::unique_ptr<AclModelOptions> options_;
  std::string options_str_;
  std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_;
};
}  // namespace mindspore::api
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H

@ -18,23 +18,31 @@
#include "utils/log_adapter.h"
#include "external/ge/ge_api_types.h"

namespace mindspore::api {
static std::string ParseOption(const std::map<std::string, std::string> &options, const std::string &key) {
  auto iter = options.find(key);
  if (iter != options.end()) {
    return iter->second;
  }
  return "";
}
namespace mindspore {
static const std::map<enum DataType, std::string> kSupportedDtypeOptionMap = {{DataType::kNumberTypeFloat16, "FP16"},
                                                                              {DataType::kNumberTypeFloat32, "FP32"},
                                                                              {DataType::kNumberTypeUInt8, "UINT8"}};

AclModelOptions::AclModelOptions(const std::map<std::string, std::string> &options) {
  // to acl
  insert_op_cfg_path = ParseOption(options, kModelOptionInsertOpCfgPath);
  input_format = ParseOption(options, kModelOptionInputFormat);
  input_shape = ParseOption(options, kModelOptionInputShape);
  output_type = ParseOption(options, kModelOptionOutputType);
  precision_mode = ParseOption(options, kModelOptionPrecisionMode);
  op_select_impl_mode = ParseOption(options, kModelOptionOpSelectImplMode);
AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
  if (context == nullptr) {
    return;
  }
  insert_op_cfg_path = ModelContext::GetInsertOpConfigPath(context);
  input_format = ModelContext::GetInputFormat(context);
  input_shape = ModelContext::GetInputShape(context);

  auto out_type = ModelContext::GetOutputType(context);
  auto iter = kSupportedDtypeOptionMap.find(out_type);
  if (out_type == DataType::kTypeUnknown) {
    // do nothing
  } else if (iter == kSupportedDtypeOptionMap.end()) {
    MS_LOG(WARNING) << "Unsupported output type " << out_type << ", use FP32 as default.";
  } else {
    output_type = iter->second;
  }

  precision_mode = ModelContext::GetPrecisionMode(context);
  op_select_impl_mode = ModelContext::GetOpSelectImplMode(context);
}

std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> AclModelOptions::GenAclOptions()

@ -69,4 +77,16 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
  }
  return {init_options, build_options};
}
}  // namespace mindspore::api

std::string AclModelOptions::GenAclOptionsKey() const {
  auto [init_options, build_options] = GenAclOptions();
  std::string key_str;
  for (auto &[key, value] : init_options) {
    key_str += key + "^" + value + "^^";
  }
  for (auto &[key, value] : build_options) {
    key_str += key + "^" + value + "^^";
  }
  return key_str;
}
}  // namespace mindspore

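Note: GenAclOptionsKey() flattens both option maps into one string so AclModel can cache a compiled graph per distinct option set. A standalone sketch of the key format (the option names shown are examples only):

#include <map>
#include <string>

// Same "<key>^<value>^^" concatenation as GenAclOptionsKey() above.
std::string GenKey(const std::map<std::string, std::string> &options) {
  std::string key;
  for (const auto &[k, v] : options) {
    key += k + "^" + v + "^^";
  }
  return key;
}
// GenKey({{"input_shape", "x:1,3,224,224"}, {"precision_mode", "allow_fp32_to_fp16"}})
// yields "input_shape^x:1,3,224,224^^precision_mode^allow_fp32_to_fp16^^".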
@ -20,12 +20,13 @@
#include <string>
#include <map>
#include <tuple>
#include <memory>
#include "include/api/types.h"
#include "include/api/status.h"
#include "include/api/context.h"

namespace mindspore::api {
namespace mindspore {
struct AclModelOptions {
  std::string output_node;  // todo: at convert.cc::BuildGraph(), no atc options
  // build options
  std::string insert_op_cfg_path;
  std::string input_format;

@ -35,12 +36,13 @@ struct AclModelOptions {
  std::string op_select_impl_mode;
  std::string soc_version = "Ascend310";

  explicit AclModelOptions(const std::map<std::string, std::string> &options);
  explicit AclModelOptions(const std::shared_ptr<Context> &context);
  ~AclModelOptions() = default;

  // return tuple<init_options, build_options>
  std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;
  std::string GenAclOptionsKey() const;
};
}  // namespace mindspore::api
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H

@ -22,9 +22,8 @@
#include "include/api/serialization.h"
#include "graph/model.h"
#include "cxx_api/model/model_converter_utils/multi_process.h"
#include "cxx_api/python_utils.h"

namespace mindspore::api {
namespace mindspore {
namespace {
transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
  transform::TensorOrderMap res;

@ -86,25 +85,25 @@ transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &
    para->set_name(name);
  }

  transform::DfGraphConvertor convertor(anf_graph);
  transform::DfGraphConvertor converter(anf_graph);
  std::string net_id = "0";
  std::string init_graph = "init_subgraph." + net_id;
  std::string checkpoint_name = "save." + net_id;

  convertor.set_training(false);
  (void)convertor.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph();
  (void)convertor.GenerateCheckpointGraph();
  if (convertor.ErrCode() != 0) {
  converter.set_training(false);
  (void)converter.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph();
  (void)converter.GenerateCheckpointGraph();
  if (converter.ErrCode() != 0) {
    transform::DfGraphManager::GetInstance().ClearGraph();
    MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode();
    MS_LOG(ERROR) << "Convert df graph failed, err:" << converter.ErrCode();
    return nullptr;
  }
  (void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), convertor.GetComputeGraph());
  (void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph());
  (void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph());
  (void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), converter.GetComputeGraph());
  (void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, converter.GetInitGraph());
  (void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, converter.GetBroadcastGraph());

  transform::Status ret =
    transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph());
    transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, converter.GetSaveCheckpointGraph());
  if (ret == transform::Status::SUCCESS) {
    transform::DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph);
  }

@ -158,7 +157,7 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
    auto df_graph = ConvertFuncGraphToAIR(func_graph);
    if (df_graph == nullptr) {
      MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
      return FAILED;
      return kMCFailed;
    }
    ge::Model model;
    ge::Buffer model_data;

@ -166,14 +165,14 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
    auto ge_ret = model.Save(model_data);
    if (ge_ret != ge::SUCCESS) {
      MS_LOG(ERROR) << "Save ge model to buffer failed.";
      return FAILED;
      return kMCFailed;
    }

    // send original model to child
    auto status = multi_process->SendMsg(model_data.data(), model_data.size());
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Send original model to child process failed";
      return FAILED;
      return status;
    }
    // receive convert model result from child
    CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {

@ -181,11 +180,11 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
      return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
    };
    status = multi_process->ReceiveMsg(call);
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Receive result model from child process failed";
      return FAILED;
      return status;
    }
    return SUCCESS;
    return kSuccess;
  };
  auto child_process = [this](MultiProcess *multi_process) -> Status {
    MS_EXCEPTION_IF_NULL(multi_process);

@ -196,25 +195,25 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
      return reinterpret_cast<uint8_t *>(model.MutableData());
    };
    auto status = multi_process->ReceiveMsg(call);
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Receive original model from parent process failed";
      return FAILED;
      return status;
    }
    Buffer model_result = LoadAscendIRInner(model);
    if (model_result.DataSize() == 0) {
      MS_LOG_ERROR << "Convert model from MindIR to OM failed";
      return FAILED;
      return kMCFailed;
    }
    // send result model to parent
    status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Send result model to parent process failed";
      return FAILED;
      return status;
    }
    return SUCCESS;
    return kSuccess;
  };
  auto status = multi_process.MainProcess(parent_process, child_process);
  if (!status.IsSuccess()) {
  if (status != kSuccess) {
    MS_LOG_ERROR << "Convert MindIR model to OM model failed";
  } else {
    MS_LOG_INFO << "Convert MindIR model to OM model success";

@ -229,9 +228,9 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
    MS_EXCEPTION_IF_NULL(multi_process);
    // send original model to child
    auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Send original model to child process failed";
      return FAILED;
      return status;
    }
    // receive convert model result from child
    CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {

@ -239,11 +238,11 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
      return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
    };
    status = multi_process->ReceiveMsg(call);
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Receive result model from child process failed";
      return FAILED;
      return status;
    }
    return SUCCESS;
    return kSuccess;
  };
  auto child_process = [this](MultiProcess *multi_process) -> Status {
    MS_EXCEPTION_IF_NULL(multi_process);

@ -254,25 +253,25 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
      return reinterpret_cast<uint8_t *>(model.MutableData());
    };
    auto status = multi_process->ReceiveMsg(call);
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Receive original model from parent process failed";
      return FAILED;
      return status;
    }
    Buffer model_result = LoadAscendIRInner(model);
    if (model_result.DataSize() == 0) {
      MS_LOG_ERROR << "Convert model from AIR to OM failed";
      return FAILED;
      return kMCFailed;
    }
    // send result model to parent
    status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
    if (!status.IsSuccess()) {
    if (status != kSuccess) {
      MS_LOG_ERROR << "Send result model to parent process failed";
      return FAILED;
      return status;
    }
    return SUCCESS;
    return kSuccess;
  };
  auto status = multi_process.MainProcess(parent_process, child_process);
  if (!status.IsSuccess()) {
  if (status != kSuccess) {
    MS_LOG_ERROR << "Convert AIR model to OM model failed";
  } else {
    MS_LOG_INFO << "Convert AIR model to OM model success";

@ -326,4 +325,4 @@ Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
  auto om_data = BuildAirModel(df_graph, init_options, build_options);
  return om_data;
}
}  // namespace mindspore::api
}  // namespace mindspore

@ -27,7 +27,7 @@
#include "external/ge/ge_ir_build.h"
#include "cxx_api/model/acl/acl_model_options.h"

namespace mindspore::api {
namespace mindspore {
class ModelConverter {
 public:
  ModelConverter() : options_(nullptr) {}

@ -46,6 +46,5 @@ class ModelConverter {
  Buffer LoadMindIRInner(const FuncGraphPtr &func_graph);
  Buffer LoadAscendIRInner(const Buffer &model_data);
};
}  // namespace mindspore::api

}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H

@ -19,49 +19,45 @@
#include "cxx_api/factory.h"
#include "utils/utils.h"

namespace mindspore::api {
Status Model::Build(const std::map<std::string, std::string> &options) {
namespace mindspore {
Status Model::Build() {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Build(options);
  return impl_->Build();
}

Status Model::Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) {
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Train(dataset, outputs);
  return impl_->Resize(inputs, dims);
}

Status Model::Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Eval(dataset, outputs);
}

Status Model::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Predict(inputs, outputs);
}

Status Model::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                            std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> Model::GetInputs() {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->GetInputsInfo(names, shapes, data_types, mem_sizes);
  return impl_->GetInputs();
}

Status Model::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                             std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> Model::GetOutputs() {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
  return impl_->GetOutputs();
}

Model::Model(const GraphCell &graph_cell)
    : impl_(Factory<ModelImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
Model::Model(const GraphCell &graph_cell, const std::shared_ptr<Context> &model_context)
    : impl_(Factory<ModelImpl>::Instance().Create(mindspore::GlobalContext::GetGlobalDeviceTarget())) {
  if (impl_ == nullptr) {
    MS_LOG(EXCEPTION) << "Create session type " << Context::Instance().GetDeviceTarget() << " failed";
    MS_LOG(EXCEPTION) << "Create session type " << mindspore::GlobalContext::GetGlobalDeviceTarget() << " failed";
  }
  MS_EXCEPTION_IF_NULL(graph_cell.GetGraph());
  impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
  impl_->SetContext(model_context);
}

Model::Model(const std::vector<Output> &network) { MS_LOG(EXCEPTION) << "Unsupported feature."; }
Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context) {
  MS_LOG(EXCEPTION) << "Unsupported feature.";
}

Model::~Model() {}

@ -69,4 +65,4 @@ bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
  return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
}

}  // namespace mindspore::api
}  // namespace mindspore

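Note: putting the refactored pieces together, a caller-side sketch of the new flow (Serialization::LoadModel -> GraphCell -> Model). The file name is an assumption, the GraphCell-from-Graph construction is inferred from the GraphCell usage shown in this diff, and device-target setup is assumed to happen elsewhere (the commit only shows the getter GlobalContext::GetGlobalDeviceTarget()):

#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"

int main() {
  // Load a MindIR graph from disk ("net.mindir" is illustrative).
  mindspore::Graph graph = mindspore::Serialization::LoadModel("net.mindir", mindspore::kMindIR);
  auto context = std::make_shared<mindspore::ModelContext>();
  mindspore::Model model(mindspore::GraphCell(graph), context);
  if (model.Build() != mindspore::kSuccess) {
    return 1;
  }
  std::vector<mindspore::MSTensor> inputs = model.GetInputs();  // fill with data before running
  std::vector<mindspore::MSTensor> outputs;
  return model.Predict(inputs, &outputs) == mindspore::kSuccess ? 0 : 1;
}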
@ -24,7 +24,6 @@
#include "cxx_api/model/model_converter_utils/shared_memory.h"

namespace mindspore {
namespace api {
namespace {
uint64_t kSharedMemorySize = 100ull << 20;  // 100 MB
}

@ -40,7 +39,7 @@ Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall
  memory_size_ = kSharedMemorySize;  // 100 MB
  SharedMemory shared_memory;
  ret = shared_memory.Create(memory_size_);
  if (!ret.IsSuccess()) {
  if (ret != kSuccess) {
    MS_LOG_ERROR << "Create shared memory failed";
    return ret;
  }

@ -48,10 +47,10 @@ Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall
  if (pid < 0) {
    shared_memory.Destroy();
    MS_LOG_ERROR << "Fork process to convert model failed";
    return FAILED;
    return kMEFailed;
  }
  ret = shared_memory.Attach();
  if (!ret.IsSuccess()) {
  if (ret != kSuccess) {
    MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid;
    return ret;
  }

@ -87,12 +86,12 @@ Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) {
  Status ret;
  try {
    ret = parent_process(this);
    if (!ret.IsSuccess()) {
    if (ret != kSuccess) {
      MS_LOG_ERROR << "Parent process process failed";
    }
  } catch (const std::runtime_error &ex) {
    MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what();
    ret = FAILED;
    ret = kMEFailed;
  }
  stopped_ = true;
  send_msg_->stop = true;

@ -108,7 +107,7 @@ void MultiProcess::ChildProcess(ProcessFuncCall child_process) {
  std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
  try {
    auto ret = child_process(this);
    if (!ret.IsSuccess()) {
    if (ret != kSuccess) {
      MS_LOG_ERROR << "Child process process failed";
    }
  } catch (const std::runtime_error &ex) {

@ -138,14 +137,14 @@ Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
    }
    if (peer_stopped_) {
      if (!send_msg_->read_finish_flag) {
        return FAILED;
        return kMEFailed;
      }
      break;
    }
    MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
  }
  MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len;
  return SUCCESS;
  return kSuccess;
}

Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {

@ -158,7 +157,7 @@ Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
      usleep(1000);  // 1ms
    }
    if (peer_stopped_) {
      return FAILED;
      return kMEFailed;
    }
    if (msg_buffer == nullptr) {
      msg_len = receive_msg_->msg_total_len;

@ -170,7 +169,7 @@ Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
    receive_msg_->read_finish_flag = true;
    MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
  } while (msg_len > cur_offset);
  return SUCCESS;
  return kSuccess;
}

void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }

@ -200,6 +199,4 @@ void MultiProcess::HeartbeatThreadFuncInner() {
    usleep(100000);  // sleep 100 ms
  }
}

}  // namespace api
}  // namespace mindspore

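Note: the Status refactor changes the MultiProcess contract from status.IsSuccess() to direct comparison against kSuccess. A condensed sketch of the parent-side protocol as used by the converter lambdas above; the function name is illustrative, and Buffer::ResizeData on the public Buffer is an assumption based on the Impl shown later in this commit:

#include <cstdint>
#include "include/api/types.h"
#include "cxx_api/model/model_converter_utils/multi_process.h"

// Parent side: send a model blob to the forked child, then receive the
// converted result into a caller-owned buffer sized by the child's reply.
mindspore::Status RoundTrip(mindspore::MultiProcess *mp, const mindspore::Buffer &in,
                            mindspore::Buffer *out) {
  mindspore::Status status = mp->SendMsg(in.Data(), in.DataSize());
  if (status != mindspore::kSuccess) {
    return status;
  }
  mindspore::CreateBufferCall call = [out](size_t msg_len) -> uint8_t * {
    out->ResizeData(msg_len);  // assumed public mirror of Impl::ResizeData
    return reinterpret_cast<uint8_t *>(out->MutableData());
  };
  return mp->ReceiveMsg(call);
}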
@ -21,7 +21,6 @@
#include "include/api/status.h"

namespace mindspore {
namespace api {
struct MessageFlag {
  uint64_t heartbeat = 0;
  uint64_t stop = false;

@ -60,7 +59,5 @@ class MultiProcess {
  Status ParentProcess(ProcessFuncCall parent_process);
  void ChildProcess(ProcessFuncCall child_process);
};
}  // namespace api
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H

@ -20,26 +20,25 @@
#include "mindspore/core/utils/log_adapter.h"

namespace mindspore {
namespace api {
Status SharedMemory::Create(uint64_t memory_size) {
  auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
  shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode);
  if (shm_id_ == -1) {
    MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno);
    return FAILED;
    return kMCFailed;
  }
  MS_LOG_INFO << "shmget success, shm id " << shm_id_;
  return SUCCESS;
  return kSuccess;
}

Status SharedMemory::Attach() {
  void *shmat_addr = shmat(shm_id_, nullptr, 0);
  if (shmat_addr == reinterpret_cast<void *>(-1)) {
    MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno);
    return FAILED;
    return kMCFailed;
  }
  shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr);
  return SUCCESS;
  return kSuccess;
}

void SharedMemory::Detach() {

@ -63,5 +62,4 @@ void SharedMemory::Destroy() {
    MS_LOG_ERROR << errMsg;
  }
}
}  // namespace api
}  // namespace mindspore

@ -20,7 +20,6 @@
#include "include/api/status.h"

namespace mindspore {
namespace api {
class SharedMemory {
 public:
  Status Create(uint64_t memory_size);

@ -33,7 +32,5 @@ class SharedMemory {
  int shm_id_ = -1;
  uint8_t *shmat_addr_ = nullptr;
};
}  // namespace api
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H

@ -21,28 +21,26 @@
#include <vector>
#include <memory>
#include <utility>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"
#include "ir/func_graph.h"

namespace mindspore::api {
namespace mindspore {
class ModelImpl {
 public:
  ModelImpl() = default;
  virtual ~ModelImpl() = default;

  virtual Status Build(const std::map<std::string, std::string> &options) = 0;
  virtual Status Build() = 0;
  virtual Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) = 0;

  virtual Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
  virtual Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
  virtual Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
  virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;

  virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                               std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
  virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                                std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
  virtual std::vector<MSTensor> GetInputs() = 0;
  virtual std::vector<MSTensor> GetOutputs() = 0;

 protected:
  Status Load(const std::shared_ptr<GraphCell> &graph_cell) {

@ -61,11 +59,16 @@ class ModelImpl {
  }

  std::shared_ptr<Graph> graph_;
  std::shared_ptr<Context> model_context_;

 private:
  friend class Model;
  void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
  void SetContext(const std::shared_ptr<Context> &model_context) {
    if (model_context != nullptr) {
      model_context_ = std::make_shared<Context>(*model_context);
    }
  }
};
}  // namespace mindspore::api

}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H

@ -16,18 +16,78 @@

#include "cxx_api/model/ms/ms_model.h"
#include <memory>
#include "include/api/context.h"
#include "utils/ms_context.h"
#include "cxx_api/factory.h"

namespace mindspore {
namespace api {
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
API_FACTORY_REG(ModelImpl, GPU, MsModel);

Status MsModel::Build(const std::map<std::string, std::string> &) {
static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
  std::string shape_key;
  for (size_t i = 0; i < dims.size(); ++i) {
    shape_key += std::to_string(i) + ":";
    for (size_t j = 0; j < dims[i].size(); ++j) {
      shape_key += std::to_string(dims[i][j]);
      if (j + 1 < dims[i].size()) {
        shape_key += ",";
      }
    }
    if (i + 1 < dims.size()) {
      shape_key += ";";
    }
  }
  return shape_key;
}

std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims) {
  std::string shape_key = GenerateShapeKey(dims);
  if (auto iter = dynamic_size_graph_map_.find(shape_key); iter != dynamic_size_graph_map_.end()) {
    MS_LOG(INFO) << "This options has been built, read cache.";
    return iter->second;
  }

  auto func_graph = ModelImpl::GetFuncGraph();
  MS_EXCEPTION_IF_NULL(func_graph);

  const auto &inputs = func_graph->parameters();
  if (dims.size() != inputs.size()) {
    MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match model inputs size " << inputs.size();
    return nullptr;
  }
  for (size_t i = 0; i < dims.size(); ++i) {
    const auto &param = inputs[i];
    auto shape_ptr = std::dynamic_pointer_cast<abstract::Shape>(param->Shape());
    if (shape_ptr == nullptr) {
      MS_LOG(ERROR) << "Inputs " << i << " is not supported to resize, debug string: " << param->DebugString();
      return nullptr;
    }
    shape_ptr->shape() = dims[i];
  }

  auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(func_graph, ModelType::kMindIR));
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_cell = std::make_shared<GraphCell>(graph);
  MS_EXCEPTION_IF_NULL(graph_cell);
  auto ret = ModelImpl::Load(graph_cell);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Load failed.";
    return nullptr;
  }
  dynamic_size_graph_map_[shape_key] = graph_cell;
  return graph_cell;
}

Status MsModel::Build() {
  MS_LOG(INFO) << "Start build model.";
  MS_EXCEPTION_IF_NULL(graph_);

  if (graph_cell_ != nullptr) {
    MS_LOG(INFO) << "This model has been built, skip.";
    return kSuccess;
  }

  auto func_graph = ModelImpl::GetFuncGraph();
  MS_EXCEPTION_IF_NULL(func_graph);

@ -36,7 +96,7 @@ Status MsModel::Build(const std::map<std::string, std::string> &) {
  auto graph_cell = std::make_shared<GraphCell>(graph);
  MS_EXCEPTION_IF_NULL(graph_cell);
  auto ret = ModelImpl::Load(graph_cell);
  if (ret != SUCCESS) {
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Load failed.";
    return ret;
  }

@ -44,55 +104,66 @@ Status MsModel::Build(const std::map<std::string, std::string> &) {
  // save result
  graph_cell_ = graph_cell;
  MS_LOG(INFO) << "Build model success.";
  return SUCCESS;
  return kSuccess;
}

Status MsModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
Status MsModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  MS_LOG(INFO) << "Start to resize model";
  auto origin_inputs = GetInputs();
  if (inputs.size() != origin_inputs.size()) {
    MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size();
    return kMCInvalidInput;
  }

  if (inputs.size() != dims.size()) {
    MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size();
    return kMCInvalidInput;
  }

  auto graph_cell = GenerateGraphCell(dims);
  if (graph_cell == nullptr) {
    MS_LOG(ERROR) << "GenerateGraphCell failed.";
    return kMCFailed;
  }

  MS_LOG(INFO) << "Resize model success.";
  graph_cell_ = std::move(graph_cell);
  return kSuccess;
}

Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
}

Status MsModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
Status MsModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  if (graph_ == nullptr) {
    MS_LOG(ERROR) << "Invalid data, graph_ is null.";
    return FAILED;
    return kMCFailed;
  }

  if (graph_cell_ == nullptr) {
    MS_LOG(INFO) << "Model has not been built, it will be built with default options";
    Status ret = Build({});
    if (ret != SUCCESS) {
    Status ret = Build();
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed.";
      return FAILED;
      return ret;
    }
  }

  MS_EXCEPTION_IF_NULL(graph_cell_);
  Status ret = graph_cell_->Run(inputs, outputs);
  if (ret != SUCCESS) {
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Run graph failed.";
    return FAILED;
    return ret;
  }

  return SUCCESS;
  return kSuccess;
}

Status MsModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                              std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> MsModel::GetInputs() {
  MS_EXCEPTION_IF_NULL(graph_cell_);
  return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes);
  return graph_cell_->GetInputs();
}

Status MsModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                               std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
std::vector<MSTensor> MsModel::GetOutputs() {
  MS_EXCEPTION_IF_NULL(graph_cell_);
  return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
  return graph_cell_->GetOutputs();
}
}  // namespace api
}  // namespace mindspore

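Note: GenerateShapeKey() above produces the cache key for dynamic-shape graphs: input index, a colon, comma-separated dims, entries joined by semicolons. A standalone copy for reference, behaviorally identical to the static helper in the diff:

#include <cstdint>
#include <string>
#include <vector>

std::string ShapeKey(const std::vector<std::vector<int64_t>> &dims) {
  std::string key;
  for (size_t i = 0; i < dims.size(); ++i) {
    key += std::to_string(i) + ":";
    for (size_t j = 0; j < dims[i].size(); ++j) {
      key += std::to_string(dims[i][j]);
      if (j + 1 < dims[i].size()) key += ",";
    }
    if (i + 1 < dims.size()) key += ";";
  }
  return key;
}
// ShapeKey({{1, 3, 224, 224}, {1}}) == "0:1,3,224,224;1:1"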
@ -33,26 +33,24 @@
#endif

namespace mindspore {
namespace api {
class MsModel : public ModelImpl {
 public:
  MsModel() {}
  ~MsModel() = default;

  Status Build(const std::map<std::string, std::string> &options_map) override;
  Status Build() override;
  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;

  Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
  Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
  Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;

  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
  std::vector<MSTensor> GetInputs() override;
  std::vector<MSTensor> GetOutputs() override;

 private:
  std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims);

  std::shared_ptr<GraphCell> graph_cell_;
  std::map<std::string, std::shared_ptr<GraphCell>> dynamic_size_graph_map_;
};
}  // namespace api
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H

@ -15,7 +15,7 @@
*/
#include "include/api/ops/ops.h"

namespace mindspore::api {
namespace mindspore {
Conv2D::Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
               const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, int group)
    : OpCell("Conv2D"),

@ -35,4 +35,4 @@ Output Conv2D::operator()(const Input &input1, const Input &input2) const {
std::vector<Output> Conv2D::Construct(const std::vector<Input> &inputs) {
  return {Output(shared_from_this(), inputs, 1)};
}
}  // namespace mindspore::api
}  // namespace mindspore

@ -29,7 +29,7 @@ namespace py = pybind11;
static std::mutex init_mutex;
static bool Initialized = false;

namespace mindspore::api {
namespace mindspore {
static void RegAllOpFromPython() {
  MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
  Py_Initialize();

@ -143,4 +143,4 @@ PythonEnvGuard::~PythonEnvGuard() {
    FinalizePython();
  }
}
}  // namespace mindspore::api
}  // namespace mindspore

@ -16,7 +16,7 @@
#ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
#define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H

namespace mindspore::api {
namespace mindspore {
void RegAllOp();
bool PythonIsInited();
void InitPython();

@ -30,5 +30,5 @@ class PythonEnvGuard {
 private:
  bool origin_init_status_;
};
}  // namespace mindspore::api
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H

@ -19,7 +19,7 @@
#include "utils/log_adapter.h"
#include "mindspore/core/load_mindir/load_model.h"

namespace mindspore::api {
namespace mindspore {
static Buffer ReadFile(const std::string &file) {
  Buffer buffer;
  if (file.empty()) {

@ -68,6 +68,22 @@ static Buffer ReadFile(const std::string &file) {
  return buffer;
}

Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelType model_type) {
  if (model_type == kMindIR) {
    FuncGraphPtr anf_graph = nullptr;
    try {
      anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
    } catch (const std::exception &) {
      MS_LOG(EXCEPTION) << "Load MindIR failed.";
    }

    return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
  } else if (model_type == kOM) {
    return Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
  }
  MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
}

Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
  Buffer data = ReadFile(file);
  if (data.Data() == nullptr) {

@ -77,7 +93,7 @@ Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
  FuncGraphPtr anf_graph = nullptr;
  try {
    anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize());
  } catch (std::exception &e) {
  } catch (const std::exception &) {
    MS_LOG(EXCEPTION) << "Load MindIR failed.";
  }

@ -90,21 +106,21 @@ Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {

Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
  return kMEFailed;
}

Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
  return kMEFailed;
}

Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
  return kMEFailed;
}

Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
  return kMEFailed;
}
}  // namespace mindspore::api
}  // namespace mindspore

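Note: the new in-memory overload of Serialization::LoadModel mirrors the file-based one. A sketch assuming a MindIR blob already resident in memory (how the blob was obtained is outside this commit):

#include <cstddef>
#include "include/api/serialization.h"

mindspore::Graph LoadFromMemory(const void *blob, size_t size) {
  // Throws via MS_LOG(EXCEPTION) on malformed input instead of returning a Status.
  return mindspore::Serialization::LoadModel(blob, size, mindspore::kMindIR);
}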
@ -17,17 +17,20 @@
|
|||
#include <numeric>
|
||||
#include "securec/include/securec.h"
|
||||
#include "utils/utils.h"
|
||||
#include "mindspore/core/ir/api_tensor_impl.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
const char *kDeviceTypeAscend310 = "Ascend310";
|
||||
const char *kDeviceTypeAscend910 = "Ascend910";
|
||||
const char *kDeviceTypeGpu = "GPU";
|
||||
|
||||
class DataImpl {
|
||||
namespace mindspore {
|
||||
class Buffer::Impl {
|
||||
public:
|
||||
DataImpl() : data_() {}
|
||||
~DataImpl() = default;
|
||||
DataImpl(const void *data, size_t data_len) { SetData(data, data_len); }
|
||||
Impl() : data_() {}
|
||||
~Impl() = default;
|
||||
Impl(const void *data, size_t data_len) {
|
||||
if (data != nullptr) {
|
||||
(void)SetData(data, data_len);
|
||||
} else {
|
||||
ResizeData(data_len);
|
||||
}
|
||||
}
|
||||
|
||||
const void *Data() const { return data_.data(); }
|
||||
void *MutableData() { return data_.data(); }
|
||||
|
@ -66,132 +69,162 @@ class DataImpl {
|
|||
std::vector<uint8_t> data_;
|
||||
};
|
||||
|
||||
class Buffer::Impl : public DataImpl {
|
||||
class TensorDefaultImpl : public MSTensor::Impl {
|
||||
public:
|
||||
Impl() : DataImpl() {}
|
||||
~Impl() = default;
|
||||
Impl(const void *data, size_t data_len) : DataImpl(data, data_len) {}
|
||||
};
|
||||
TensorDefaultImpl() : buffer_(), name_(), type_(DataType::kTypeUnknown), shape_() {}
|
||||
~TensorDefaultImpl() override = default;
|
||||
TensorDefaultImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: buffer_(data, data_len), name_(name), type_(type), shape_(shape) {}
|
||||
|
||||
class Tensor::Impl : public DataImpl {
|
||||
public:
|
||||
Impl() : DataImpl(), name_(), type_(DataType::kMsUnknown), shape_() {}
|
||||
~Impl() = default;
|
||||
Impl(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: DataImpl(data, data_len), name_(name), type_(type), shape_(shape) {}
|
||||
const std::string &Name() const override { return name_; }
|
||||
enum DataType DataType() const override { return type_; }
|
||||
const std::vector<int64_t> &Shape() const override { return shape_; }
|
||||
|
||||
const std::string &Name() const { return name_; }
|
||||
void SetName(const std::string &name) { name_ = name; }
|
||||
|
||||
api::DataType DataType() const { return type_; }
|
||||
void SetDataType(api::DataType type) { type_ = type; }
|
||||
|
||||
void SetShape(const std::vector<int64_t> &shape) { shape_ = shape; }
|
||||
const std::vector<int64_t> &Shape() const { return shape_; }
|
||||
|
||||
int64_t ElementNum() const {
|
||||
std::vector<int64_t> shapex = Shape();
|
||||
return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
|
||||
std::shared_ptr<const void> Data() const override {
|
||||
return std::shared_ptr<const void>(buffer_.Data(), [](const void *) {});
|
||||
}
|
||||
|
||||
static int GetTypeSize(api::DataType type) {
|
||||
static const std::map<api::DataType, size_t> type_size_map = {
|
||||
{kMsBool, sizeof(bool)}, {kMsFloat64, sizeof(double)}, {kMsInt8, sizeof(int8_t)},
|
||||
{kMsUint8, sizeof(uint8_t)}, {kMsInt16, sizeof(int16_t)}, {kMsUint16, sizeof(uint16_t)},
|
||||
{kMsInt32, sizeof(int32_t)}, {kMsUint32, sizeof(uint32_t)}, {kMsInt64, sizeof(int64_t)},
|
||||
{kMsUint64, sizeof(uint64_t)}, {kMsFloat16, sizeof(uint16_t)}, {kMsFloat32, sizeof(float)},
|
||||
};
|
||||
auto it = type_size_map.find(type);
|
||||
if (it != type_size_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
void *MutableData() override { return buffer_.MutableData(); }
|
||||
size_t DataSize() const override { return buffer_.DataSize(); }
|
||||
|
||||
MS_LOG(WARNING) << "Cannot find data type " << type;
|
||||
return 0;
|
||||
bool IsDevice() const override { return false; }
|
||||
|
||||
std::shared_ptr<Impl> Clone() const override {
|
||||
return std::make_shared<TensorDefaultImpl>(name_, type_, shape_, buffer_.Data(), buffer_.DataSize());
|
||||
}
|
||||
|
||||
private:
|
||||
Buffer buffer_;
|
||||
std::string name_;
|
||||
api::DataType type_;
|
||||
enum DataType type_;
|
||||
std::vector<int64_t> shape_;
|
||||
};
|
||||
|
||||
Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
|
||||
Tensor::Tensor(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: impl_(std::make_shared<Impl>(name, type, shape, data, data_len)) {}
|
||||
Tensor::~Tensor() = default;
|
||||
class TensorReferenceImpl : public MSTensor::Impl {
|
||||
public:
|
||||
TensorReferenceImpl() : data_(nullptr), data_size_(0), name_(), type_(DataType::kTypeUnknown), shape_() {}
|
||||
~TensorReferenceImpl() override = default;
|
||||
TensorReferenceImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: data_(data), data_size_(data_len), name_(name), type_(type), shape_(shape) {}
|
||||
|
||||
Tensor Tensor::Clone() const {
|
||||
const std::string &Name() const override { return name_; }
|
  enum DataType DataType() const override { return type_; }
  const std::vector<int64_t> &Shape() const override { return shape_; }

  std::shared_ptr<const void> Data() const override {
    return std::shared_ptr<const void>(data_, [](const void *) {});
  }

  void *MutableData() override { return const_cast<void *>(data_); }
  size_t DataSize() const override { return data_size_; }

  bool IsDevice() const override { return false; }

  std::shared_ptr<Impl> Clone() const override {
    return std::make_shared<TensorReferenceImpl>(name_, type_, shape_, data_, data_size_);
  }

 protected:
  const void *data_;
  size_t data_size_;
  std::string name_;
  enum DataType type_;
  std::vector<int64_t> shape_;
};

MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
                                const void *data, size_t data_len) noexcept {
  try {
    std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len);
    return MSTensor(impl);
  } catch (const std::bad_alloc &) {
    MS_LOG(ERROR) << "Malloc memory failed.";
    return MSTensor(nullptr);
  } catch (...) {
    MS_LOG(ERROR) << "Unknown error occurred.";
    return MSTensor(nullptr);
  }
}

MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
                                   const void *data, size_t data_len) noexcept {
  try {
    std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len);
    return MSTensor(impl);
  } catch (const std::bad_alloc &) {
    MS_LOG(ERROR) << "Malloc memory failed.";
    return MSTensor(nullptr);
  } catch (...) {
    MS_LOG(ERROR) << "Unknown error occurred.";
    return MSTensor(nullptr);
  }
}

MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); }
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
                   size_t data_len)
    : impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default;

bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }

MSTensor MSTensor::Clone() const {
  MS_EXCEPTION_IF_NULL(impl_);
  Tensor ret;
  ret.impl_ = std::make_shared<Impl>(*impl_);
  MSTensor ret;
  ret.impl_ = impl_->Clone();
  return ret;
}

const std::string &Tensor::Name() const {
const std::string &MSTensor::Name() const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Name();
}

void Tensor::SetName(const std::string &name) {
  MS_EXCEPTION_IF_NULL(impl_);
  impl_->SetName(name);
}

DataType Tensor::DataType() const {
enum DataType MSTensor::DataType() const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->DataType();
}

void Tensor::SetDataType(api::DataType type) {
  MS_EXCEPTION_IF_NULL(impl_);
  impl_->SetDataType(type);
}

const std::vector<int64_t> &Tensor::Shape() const {
const std::vector<int64_t> &MSTensor::Shape() const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Shape();
}

void Tensor::SetShape(const std::vector<int64_t> &shape) {
int64_t MSTensor::ElementNum() const {
  MS_EXCEPTION_IF_NULL(impl_);
  impl_->SetShape(shape);
  const auto &shape = impl_->Shape();
  if (shape.empty()) {
    // element number of scalar is 1
    return 1;
  }

  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
}

const void *Tensor::Data() const {
std::shared_ptr<const void> MSTensor::Data() const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Data();
}

void *Tensor::MutableData() {
void *MSTensor::MutableData() {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->MutableData();
}

size_t Tensor::DataSize() const {
size_t MSTensor::DataSize() const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->DataSize();
}

bool Tensor::ResizeData(size_t data_len) {
bool MSTensor::IsDevice() const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->ResizeData(data_len);
  return impl_->IsDevice();
}

bool Tensor::SetData(const void *data, size_t data_len) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->SetData(data, data_len);
}

int64_t Tensor::ElementNum() const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->ElementNum();
}

int Tensor::GetTypeSize(api::DataType type) { return Impl::GetTypeSize(type); }

Buffer::Buffer() : impl_(std::make_shared<Impl>()) {}
Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {}
Buffer::~Buffer() = default;

@@ -227,4 +260,4 @@ bool Buffer::SetData(const void *data, size_t data_len) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->SetData(data, data_len);
}
}  // namespace mindspore::api
}  // namespace mindspore
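Editor's note: the two factories above differ only in which Impl they build. TensorReferenceImpl plainly aliases the caller's buffer (its Data() installs a no-op deleter), while TensorDefaultImpl presumably owns a copy, though its body is not shown in this diff. A minimal usage sketch of the refactored API follows; the include path and the kNumberTypeFloat32 enumerator are assumptions, since neither appears in this hunk.

```cpp
#include <cstdint>
#include <vector>
#include "include/api/types.h"  // assumed public header, per the install rules earlier in this commit

void TensorFactorySketch() {
  std::vector<float> buf = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<int64_t> shape = {2, 2};

  // Copying factory: the tensor is backed by a TensorDefaultImpl.
  auto owned = mindspore::MSTensor::CreateTensor("t0", mindspore::DataType::kNumberTypeFloat32, shape, buf.data(),
                                                 buf.size() * sizeof(float));

  // Referencing factory: TensorReferenceImpl keeps only a raw pointer, so
  // `buf` must outlive `ref`; Clone() re-wraps the same buffer.
  auto ref = mindspore::MSTensor::CreateRefTensor("t1", mindspore::DataType::kNumberTypeFloat32, shape, buf.data(),
                                                  buf.size() * sizeof(float));

  // Both factories are noexcept and signal failure with a null tensor.
  if (owned == nullptr || ref == nullptr) {
    return;
  }
  mindspore::MSTensor cloned = ref.Clone();
}
```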

@@ -284,14 +284,7 @@ else()
endif()

add_dependencies(_c_dataengine mindspore_shared_lib)
if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
  set(MINDSPORE_LINK_OBJECT ${CMAKE_BINARY_DIR}/mindspore/ccsrc/cxx_api/CMakeFiles/mindspore_shared_lib.dir/objects.a)
  target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib ${MINDSPORE_LINK_OBJECT})
else()
  if(ENABLE_ACL)
    target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)
  endif()
endif()
target_link_libraries(_c_dataengine PRIVATE mindspore_shared_lib)

if(USE_GLOG)
  target_link_libraries(_c_dataengine PRIVATE mindspore::glog)

@@ -26,28 +26,13 @@ if(ENABLE_PYTHON)
  target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
endif()

if(ENABLE_ACL)
  add_library(cpp-API OBJECT
    config.cc
    datasets.cc
    execute.cc
    iterator.cc
    minddata_eager.cc
    transforms.cc
    samplers.cc
    text.cc
    vision.cc
  )
else()
  add_library(cpp-API OBJECT
    config.cc
    datasets.cc
    execute.cc
    iterator.cc
    transforms.cc
    samplers.cc
    text.cc
    vision.cc
  )
endif()
add_library(cpp-API OBJECT
  config.cc
  datasets.cc
  execute.cc
  iterator.cc
  transforms.cc
  samplers.cc
  text.cc
  vision.cc
)

@@ -1,142 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/include/de_tensor.h"
#include "minddata/dataset/include/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "mindspore/lite/include/ms_tensor.h"
#include "utils/hashing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#else
#include "mindspore/lite/src/common/log_adapter.h"
#endif

namespace mindspore {
namespace tensor {
MSTensor *DETensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
  return new DETensor(data_type, shape);
}

MSTensor *DETensor::CreateTensor(const std::string &path) {
  std::shared_ptr<dataset::Tensor> t;
  (void)dataset::Tensor::CreateFromFile(path, &t);
  return new DETensor(std::move(t));
}

MSTensor *DETensor::CreateFromMemory(TypeId data_type, const std::vector<int> &shape, void *data) {
  std::shared_ptr<dataset::Tensor> t;
  // prepare shape info
  std::vector<dataset::dsize_t> t_shape;

  std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
                 [](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });

  (void)dataset::Tensor::CreateFromMemory(dataset::TensorShape(t_shape), dataset::MSTypeToDEType(data_type),
                                          static_cast<uint8_t *>(data), &t);
  return new DETensor(std::move(t));
}

DETensor::DETensor(TypeId data_type, const std::vector<int> &shape) {
  std::vector<dataset::dsize_t> t_shape;
  t_shape.reserve(shape.size());
  std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
                 [](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
  dataset::Tensor::CreateEmpty(dataset::TensorShape(t_shape), dataset::MSTypeToDEType(data_type), &this->tensor_impl_);
}

DETensor::DETensor(std::shared_ptr<dataset::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }

MSTensor *DETensor::ConvertToLiteTensor() {
  // static MSTensor::CreateTensor is only for the LiteTensor
  MSTensor *tensor = CreateTensor(this->data_type(), this->shape());
  MS_ASSERT(tensor->Size() == this->Size());
  memcpy_s(tensor->MutableData(), tensor->Size(), this->MutableData(), this->Size());
  return tensor;
}

std::shared_ptr<dataset::Tensor> DETensor::tensor() const {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  return this->tensor_impl_;
}

TypeId DETensor::data_type() const {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  return dataset::DETypeToMSType(this->tensor_impl_->type());
}

TypeId DETensor::set_data_type(TypeId data_type) {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  if (data_type != this->data_type()) {
    std::shared_ptr<dataset::Tensor> temp;
    dataset::Tensor::CreateFromMemory(this->tensor_impl_->shape(), dataset::MSTypeToDEType(data_type),
                                      this->tensor_impl_->GetBuffer(), &temp);
    this->tensor_impl_ = temp;
  }
  return data_type;
}

std::vector<int> DETensor::shape() const {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  std::vector<dataset::dsize_t> t_shape = this->tensor_impl_->shape().AsVector();
  std::vector<int> shape;
  shape.reserve(t_shape.size());
  std::transform(t_shape.begin(), t_shape.end(), std::back_inserter(shape),
                 [](dataset::dsize_t s) -> int { return static_cast<int>(s); });
  return shape;
}

size_t DETensor::set_shape(const std::vector<int> &shape) {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  std::vector<dataset::dsize_t> t_shape;
  t_shape.reserve(shape.size());
  std::transform(shape.begin(), shape.end(), std::back_inserter(t_shape),
                 [](int s) -> dataset::dsize_t { return static_cast<dataset::dsize_t>(s); });
  dataset::Status rc = this->tensor_impl_->Reshape(dataset::TensorShape(t_shape));
  return shape.size();
}

int DETensor::DimensionSize(size_t index) const {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  int dim_size = -1;
  auto shape = this->shape();
  if (index < shape.size()) {
    dim_size = shape[index];
  } else {
    MS_LOG(ERROR) << "Dimension index is wrong: " << index;
  }
  return dim_size;
}

int DETensor::ElementsNum() const {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  return this->tensor_impl_->Size();
}

size_t DETensor::Size() const {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  return this->tensor_impl_->SizeInBytes();
}

void *DETensor::MutableData() {
  MS_ASSERT(this->tensor_impl_ != nullptr);
  return this->tensor_impl_->GetMutableBuffer();
}

}  // namespace tensor
}  // namespace mindspore

@@ -14,12 +14,11 @@
 * limitations under the License.
 */

#include "minddata/dataset/core/tensor_row.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/include/de_tensor.h"
#endif
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/kernels/tensor_op.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"

@@ -30,78 +29,85 @@
namespace mindspore {
namespace dataset {

Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}
Execute::Execute(std::shared_ptr<TensorOperation> op) { ops_.emplace_back(std::move(op)); }

/// \brief Destructor
Execute::~Execute() = default;
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops) : ops_(std::move(ops)) {}

#ifdef ENABLE_ANDROID
std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) {
  // Build the op
  if (op_ == nullptr) {
    MS_LOG(ERROR) << "Input TensorOperation is not valid";
    return nullptr;
Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
  // Validate input tensor
  CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data");
  CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");

  // Validate and build runtime ops
  std::vector<std::shared_ptr<TensorOp>> transforms;
  for (int32_t i = 0; i < ops_.size(); i++) {
    CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
    RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
    transforms.emplace_back(ops_[i]->Build());
  }

  std::shared_ptr<Tensor> de_input = std::dynamic_pointer_cast<tensor::DETensor>(input)->tensor();
  if (de_input == nullptr) {
    MS_LOG(ERROR) << "Input Tensor is not valid";
    return nullptr;
  }
  std::shared_ptr<TensorOp> transform = op_->Build();
  std::shared_ptr<Tensor> de_output;
  Status rc = transform->Compute(de_input, &de_output);
  // Convert mindspore::Tensor to dataset::Tensor
  std::shared_ptr<dataset::Tensor> de_tensor;
  Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(input.Shape()),
                                                MSTypeToDEType(static_cast<TypeId>(input.DataType())),
                                                (const uchar *)(input.Data().get()), input.DataSize(), &de_tensor);
  RETURN_IF_NOT_OK(rc);

  if (rc.IsError()) {
    // execution failed
    MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
    return nullptr;
  }
  return std::make_shared<tensor::DETensor>(std::move(de_output));
}
#endif
  // Apply transforms on tensor
  for (auto &t : transforms) {
    std::shared_ptr<dataset::Tensor> de_output;
    RETURN_IF_NOT_OK(t->Compute(de_tensor, &de_output));

std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) {
  // Build the op
  if (op_ == nullptr) {
    MS_LOG(ERROR) << "Input TensorOperation is not valid";
    return nullptr;
    // For next transform
    de_tensor = std::move(de_output);
  }

  if (input == nullptr) {
    MS_LOG(ERROR) << "Input Tensor is not valid";
    return nullptr;
  }
  // will add validate params once API is set
  std::shared_ptr<TensorOp> transform = op_->Build();
  std::shared_ptr<Tensor> de_output;
  Status rc = transform->Compute(input, &de_output);

  if (rc.IsError()) {
    // execution failed
    MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
    return nullptr;
  }
  return de_output;
  // Convert dataset::Tensor to mindspore::Tensor
  CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
  *output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
  return Status::OK();
}

Status Execute::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
                           std::vector<std::shared_ptr<Tensor>> *output_tensor_list) {
  CHECK_FAIL_RETURN_UNEXPECTED(op_ != nullptr, "Input TensorOperation is not valid");
Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
  // Validate input tensor
  CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid");
  for (auto &tensor : input_tensor_list) {
    CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data");
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");

  TensorRow input, output;
  std::copy(input_tensor_list.begin(), input_tensor_list.end(), std::back_inserter(input));
  CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "Input Tensor is not valid");

  std::shared_ptr<TensorOp> transform = op_->Build();
  Status rc = transform->Compute(input, &output);
  if (rc.IsError()) {
    // execution failed
    RETURN_STATUS_UNEXPECTED("Operation execution failed : " + rc.ToString());
  // Validate and build runtime ops
  std::vector<std::shared_ptr<TensorOp>> transforms;
  for (int32_t i = 0; i < ops_.size(); i++) {
    CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
    RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
    transforms.emplace_back(ops_[i]->Build());
  }

  std::copy(output.begin(), output.end(), std::back_inserter(*output_tensor_list));
  TensorRow de_tensor_list;
  for (auto &tensor : input_tensor_list) {
    std::shared_ptr<dataset::Tensor> de_tensor;
    Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(tensor.Shape()),
                                                  MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
                                                  (const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_tensor);
    RETURN_IF_NOT_OK(rc);
    de_tensor_list.emplace_back(std::move(de_tensor));
  }

  // Apply transforms on tensor
  for (auto &t : transforms) {
    TensorRow de_output_list;
    RETURN_IF_NOT_OK(t->Compute(de_tensor_list, &de_output_list));
    // For next transform
    de_tensor_list = std::move(de_output_list);
  }

  for (auto &tensor : de_tensor_list) {
    CHECK_FAIL_RETURN_UNEXPECTED(tensor->HasData(), "Apply transform failed, output tensor has no data");
    auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
    output_tensor_list->emplace_back(ms_tensor);
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid");
  return Status::OK();
}
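Editor's note: both reworked operator() overloads follow one recipe: validate inputs, null-check and ValidateParams() every op, Build() the runtime TensorOps, run them in order, and hand back MSTensor wrappers via DETensor. A caller-side sketch of the single-tensor path; the vision::Decode/Resize factory names are assumptions borrowed from the wider minddata API, not from this hunk.

```cpp
#include <memory>
#include <vector>
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/vision.h"  // assumed home of the op factories

namespace ds = mindspore::dataset;

mindspore::Status DecodeAndResizeSketch(const mindspore::MSTensor &raw_jpeg, mindspore::MSTensor *out) {
  // Execute now takes the whole pipeline up front; each op is validated
  // inside operator() before anything runs.
  std::vector<std::shared_ptr<ds::TensorOperation>> ops = {ds::vision::Decode(), ds::vision::Resize({224, 224})};
  ds::Execute transform(ops);

  // Failures surface as a Status instead of the old nullptr convention.
  return transform(raw_jpeg, out);
}
```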

@@ -1,154 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <unistd.h>
#include <unordered_map>

#include "minddata/dataset/include/minddata_eager.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"

namespace mindspore {
namespace api {

MindDataEager::MindDataEager(std::vector<std::shared_ptr<dataset::TensorOperation>> ops) : ops_(ops) {}

// Helper function to convert Type from DE to MS
DataType ToMSType(dataset::DataType type) {
  switch (dataset::DataType::Type(type)) {
    case dataset::DataType::DE_BOOL:
      return DataType::kMsBool;
    case dataset::DataType::DE_UINT8:
      return DataType::kMsUint8;
    case dataset::DataType::DE_INT32:
      return DataType::kMsInt32;
    case dataset::DataType::DE_INT64:
      return DataType::kMsInt64;
    case dataset::DataType::DE_FLOAT32:
      return DataType::kMsFloat32;
    default:
      return DataType::kMsUnknown;
  }
}

// Helper function to convert Type from MS to DE
dataset::DataType ToDEType(DataType type) {
  switch (type) {
    case DataType::kMsBool:
      return dataset::DataType(dataset::DataType::DE_BOOL);
    case DataType::kMsUint8:
      return dataset::DataType(dataset::DataType::DE_UINT8);
    case DataType::kMsInt32:
      return dataset::DataType(dataset::DataType::DE_INT32);
    case DataType::kMsInt64:
      return dataset::DataType(dataset::DataType::DE_INT64);
    case DataType::kMsFloat32:
      return dataset::DataType(dataset::DataType::DE_FLOAT32);
    default:
      return dataset::DataType(dataset::DataType::DE_UNKNOWN);
  }
}

Status MindDataEager::LoadImageFromDir(const std::string &image_dir, std::vector<std::shared_ptr<Tensor>> *images) {
  // Check target directory
  dataset::Path image_dir_(image_dir);
  if (!image_dir_.Exists() || !image_dir_.IsDirectory()) {
    std::string err_msg = "Target directory: " + image_dir + " does not exist or not a directory.";
    MS_LOG(ERROR) << err_msg;
    return Status(StatusCode::FAILED, err_msg);
  }
  if (access(image_dir_.toString().c_str(), R_OK) == -1) {
    std::string err_msg = "No access to target directory: " + image_dir;
    MS_LOG(ERROR) << err_msg;
    return Status(StatusCode::FAILED, err_msg);
  }

  // Start reading images and constructing tensors
  auto path_itr = dataset::Path::DirIterator::OpenDirectory(&image_dir_);
  while (path_itr->hasNext()) {
    dataset::Path file = path_itr->next();
    std::shared_ptr<dataset::Tensor> image;
    dataset::Tensor::CreateFromFile(file.toString(), &image);

    std::shared_ptr<Tensor> ms_image = std::make_shared<Tensor>("image", DataType(kMsUint8), image->shape().AsVector(),
                                                                image->GetBuffer(), image->SizeInBytes());
    images->push_back(ms_image);
  }

  // Check if read images or not
  if (images->empty()) {
    std::string err_msg = "No images found in target directory: " + image_dir;
    MS_LOG(ERROR) << err_msg;
    return Status(StatusCode::FAILED, err_msg);
  }

  return Status(StatusCode::SUCCESS);
}

std::shared_ptr<Tensor> MindDataEager::operator()(std::shared_ptr<Tensor> input) {
  // Validate ops
  if (ops_.empty()) {
    MS_LOG(ERROR) << "Input TensorOperation should be provided";
    return nullptr;
  }
  for (int32_t i = 0; i < ops_.size(); i++) {
    if (ops_[i] == nullptr) {
      MS_LOG(ERROR) << "Input TensorOperation[" << i << "] is invalid or null";
      return nullptr;
    }
  }
  // Validate input tensor
  if (input == nullptr) {
    MS_LOG(ERROR) << "Input Tensor should not be null";
    return nullptr;
  }

  // Start applying transforms in ops
  std::shared_ptr<dataset::Tensor> de_input;
  dataset::Tensor::CreateFromMemory(dataset::TensorShape(input->Shape()), ToDEType(input->DataType()),
                                    (const uchar *)(input->Data()), &de_input);

  for (int32_t i = 0; i < ops_.size(); i++) {
    // Build runtime op and run
    std::shared_ptr<dataset::Tensor> de_output;
    std::shared_ptr<dataset::TensorOp> transform = ops_[i]->Build();
    dataset::Status rc = transform->Compute(de_input, &de_output);

    // check execution failed
    if (rc.IsError()) {
      MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
      return nullptr;
    }

    // For next transform
    de_input = std::move(de_output);
  }

  // Convert DETensor to Tensor
  if (!de_input->HasData()) {
    MS_LOG(ERROR) << "Apply transform failed, output tensor has no data";
    return nullptr;
  }
  std::shared_ptr<Tensor> output =
    std::make_shared<Tensor>("transfomed", ToMSType(de_input->type()), de_input->shape().AsVector(),
                             de_input->GetBuffer(), de_input->SizeInBytes());
  return output;
}

}  // namespace api
}  // namespace mindspore
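Editor's note: this deletion removes the last consumer of the api::Tensor eager path; its job moves to the Execute overloads shown earlier (the install rules in this commit likewise swap minddata_eager.h for execute.h). A rough before/after, reconstructed from the two files and therefore only indicative:

```cpp
// Before (deleted API): MindDataEager signalled every failure with a null tensor.
//   api::MindDataEager eager(ops);
//   std::shared_ptr<api::Tensor> out = eager(input);  // nullptr on error
//
// After: the same ops run through dataset::Execute, and errors carry a
// code and message in a mindspore::Status instead.
mindspore::Status EagerMigrationSketch(mindspore::dataset::Execute *transform, const mindspore::MSTensor &input,
                                       mindspore::MSTensor *output) {
  return (*transform)(input, output);
}
```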

@@ -29,25 +29,42 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
  return execute;
}))
  .def("__call__",
       [](Execute &self, std::shared_ptr<Tensor> in) {
         std::shared_ptr<Tensor> out = self(in);
         if (out == nullptr) {
           THROW_IF_ERROR([]() {
             RETURN_STATUS_UNEXPECTED(
               "Failed to execute op in eager mode, please check ERROR log above.");
       [](Execute &self, const std::shared_ptr<Tensor> &de_tensor) {
         auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
         Status rc = self(ms_tensor, &ms_tensor);
         if (rc.IsError()) {
           THROW_IF_ERROR([&rc]() {
             RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString());
           }());
         }
         return out;
         std::shared_ptr<dataset::Tensor> de_output_tensor;
         dataset::Tensor::CreateFromMemory(dataset::TensorShape(ms_tensor.Shape()),
                                           MSTypeToDEType(static_cast<TypeId>(ms_tensor.DataType())),
                                           (const uchar *)(ms_tensor.Data().get()),
                                           ms_tensor.DataSize(), &de_output_tensor);
         return de_output_tensor;
       })
  .def("__call__", [](Execute &self, const std::vector<std::shared_ptr<Tensor>> &input_tensor_list) {
    std::vector<std::shared_ptr<Tensor>> output_tensor_list;
    THROW_IF_ERROR(self(input_tensor_list, &output_tensor_list));
    if (output_tensor_list.empty()) {
      THROW_IF_ERROR([]() {
        RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
      }());
    std::vector<MSTensor> ms_input_tensor_list;
    std::vector<MSTensor> ms_output_tensor_list;
    for (auto &tensor : input_tensor_list) {
      auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
      ms_input_tensor_list.emplace_back(std::move(ms_tensor));
    }
    return output_tensor_list;
    Status rc = self(ms_input_tensor_list, &ms_output_tensor_list);
    if (rc.IsError()) {
      THROW_IF_ERROR(
        [&rc]() { RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString()); }());
    }
    std::vector<std::shared_ptr<dataset::Tensor>> de_output_tensor_list;
    for (auto &tensor : ms_output_tensor_list) {
      std::shared_ptr<dataset::Tensor> de_output_tensor;
      dataset::Tensor::CreateFromMemory(
        dataset::TensorShape(tensor.Shape()), MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
        (const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_output_tensor);
      de_output_tensor_list.emplace_back(std::move(de_output_tensor));
    }
    return de_output_tensor_list;
  });
}));
}  // namespace dataset

@@ -84,7 +84,8 @@ PYBIND_REGISTER(SliceOption, 0, ([](const py::module *m) {
}

if (!c_slice.valid()) {
  THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
  THROW_IF_ERROR(
    Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
return SliceOption(c_slice);
}))

@@ -354,7 +354,7 @@ PYBIND_REGISTER(
for (auto handle : py_sub.cast<py::list>()) {
  py::tuple tp = handle.cast<py::tuple>();
  if (tp.is_none() || tp.size() != 2) {
    THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
    THROW_IF_ERROR(Status(StatusCode::kMDUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
  }
  std::shared_ptr<TensorOperation> t_op;
  if (py::isinstance<TensorOperation>(tp[0])) {

@@ -366,11 +366,11 @@
      std::make_shared<PyFuncOp>((tp[0]).cast<py::function>()));
  } else {
    THROW_IF_ERROR(
      Status(StatusCode::kUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
      Status(StatusCode::kMDUnexpectedError, "op is neither a tensorOp, tensorOperation nor a pyfunc."));
  }
  double prob = (tp[1]).cast<py::float_>();
  if (prob < 0 || prob > 1) {
    THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1]."));
    THROW_IF_ERROR(Status(StatusCode::kMDUnexpectedError, "prob needs to be with [0,1]."));
  }
  cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
}

@@ -51,12 +51,12 @@ Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param
  // Acquire Python GIL
  py::gil_scoped_acquire gil_acquire;
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  try {
    f(cb_param);
  } catch (const py::error_already_set &e) {
    return Status(StatusCode::kPyFuncException, e.what());
    return Status(StatusCode::kMDPyFuncException, e.what());
  }
}
return Status::OK();

@@ -5,6 +5,7 @@ set(DATASET_CORE_SRC_FILES
  config_manager.cc
  cv_tensor.cc
  data_type.cc
  de_tensor.cc
  global_context.cc
  tensor.cc
  tensor_helpers.cc

@@ -0,0 +1,67 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/include/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "utils/hashing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#define ASSERT_NULL(ptr) MS_EXCEPTION_IF_NULL(ptr)
#else
#include "mindspore/lite/src/common/log_adapter.h"
#define ASSERT_NULL(ptr) MS_ASSERT((ptr) != nullptr)
#endif

namespace mindspore {
namespace dataset {

DETensor::DETensor(std::shared_ptr<dataset::Tensor> tensor_impl)
    : tensor_impl_(tensor_impl),
      name_("MindDataTensor"),
      type_(static_cast<mindspore::DataType>(DETypeToMSType(tensor_impl_->type()))),
      shape_(tensor_impl_->shape().AsVector()) {}

const std::string &DETensor::Name() const { return name_; }

enum mindspore::DataType DETensor::DataType() const {
  ASSERT_NULL(tensor_impl_);
  return static_cast<mindspore::DataType>(DETypeToMSType(tensor_impl_->type()));
}

size_t DETensor::DataSize() const {
  ASSERT_NULL(tensor_impl_);
  return tensor_impl_->SizeInBytes();
}

const std::vector<int64_t> &DETensor::Shape() const { return shape_; }

std::shared_ptr<const void> DETensor::Data() const {
  return std::shared_ptr<const void>(tensor_impl_->GetBuffer(), [](const void *) {});
}

void *DETensor::MutableData() {
  ASSERT_NULL(tensor_impl_);
  return tensor_impl_->GetMutableBuffer();
}

bool DETensor::IsDevice() const { return false; }

std::shared_ptr<mindspore::MSTensor::Impl> DETensor::Clone() const { return std::make_shared<DETensor>(tensor_impl_); }
}  // namespace dataset
}  // namespace mindspore

@@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_
#include <string>
#include <vector>
#include <memory>
#include "include/api/types.h"
#include "mindspore/core/ir/api_tensor_impl.h"
#include "minddata/dataset/include/status.h"
#include "minddata/dataset/include/tensor.h"

namespace mindspore {
namespace dataset {
class DETensor : public mindspore::MSTensor::Impl {
 public:
  DETensor() = default;
  ~DETensor() override = default;
  explicit DETensor(std::shared_ptr<dataset::Tensor> tensor_impl);

  const std::string &Name() const override;

  enum mindspore::DataType DataType() const override;

  size_t DataSize() const override;

  const std::vector<int64_t> &Shape() const override;

  std::shared_ptr<const void> Data() const override;

  void *MutableData() override;

  bool IsDevice() const override;

  std::shared_ptr<mindspore::MSTensor::Impl> Clone() const override;

 private:
  std::shared_ptr<dataset::Tensor> tensor_impl_;
  std::string name_;
  enum mindspore::DataType type_;
  std::vector<int64_t> shape_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DETENSOR_H_
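Editor's note: DETensor is the adapter that lets a dataset::Tensor cross the new API boundary without a copy; the pybind and Execute hunks above construct it in exactly this way. A small sketch (the reverse direction, MSTensor back to dataset::Tensor, goes through Tensor::CreateFromMemory as in execute.cc):

```cpp
#include <memory>
#include "minddata/dataset/core/de_tensor.h"

// Wrap an existing dataset tensor as the public MSTensor handle. DETensor
// shares the underlying buffer (its Data() installs a no-op deleter), so the
// wrapped dataset::Tensor must stay alive while the MSTensor is in use.
mindspore::MSTensor WrapAsMSTensorSketch(const std::shared_ptr<mindspore::dataset::Tensor> &de_tensor) {
  return mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
}
```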

@@ -41,23 +41,17 @@
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor_helpers.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/util/status.h"
#include "utils/ms_utils.h"
#ifndef ENABLE_ANDROID
#include "proto/example.pb.h"
#else
#include "minddata/dataset/include/de_tensor.h"
#endif

#ifdef ENABLE_PYTHON
namespace py = pybind11;
#endif
namespace mindspore {
#ifdef ENABLE_ANDROID
namespace tensor {
class DETensor;
}  // namespace tensor
#endif
namespace dataset {
class Tensor;
template <typename T>

@@ -85,7 +79,7 @@ class Tensor {
  /// \param other Tensor to be moved
  Tensor(Tensor &&other) noexcept;

  /// Move assigment operator
  /// Move assignment operator
  /// \param other Tensor to be moved
  Tensor &operator=(Tensor &&other) noexcept;

@@ -134,7 +128,7 @@ class Tensor {
#ifndef ENABLE_ANDROID
  /// Create a tensor of type DE_STRING from a BytesList.
  /// \param[in] bytes_list protobuf's Bytelist
  /// \param[in] shape shape of the outout tensor
  /// \param[in] shape shape of the output tensor
  /// \param[out] out created Tensor
  /// \return Status Code
  static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out);

@@ -292,7 +286,7 @@ class Tensor {
      std::string err;
      err += (data_ == nullptr) ? "data_ is nullptr \t" : "";
      err += type_.IsCompatible<T>() ? "data type not compatible\t" : "";
      return Status(StatusCode::kUnexpectedError, err);
      return Status(StatusCode::kMDUnexpectedError, err);
    }
  }

@@ -343,7 +337,7 @@ class Tensor {
  void Invalidate();

  /// Copy input tensor into self at the location index.
  /// Index is a vector of axises which can be incomplete:
  /// Index is a vector of axes which can be incomplete:
  /// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
  /// \param index
  /// \param input

@@ -686,9 +680,7 @@ class Tensor {
  unsigned char *data_end_ = nullptr;

 private:
#ifdef ENABLE_ANDROID
  friend class tensor::DETensor;
#endif
  friend class DETensor;

  /// Slice numeric tensors.
  Status SliceNumeric(TensorPtr *out, const std::vector<std::vector<dsize_t>> &indices, const TensorShape &shape);

@@ -73,6 +73,7 @@ if(ENABLE_CACHE)
  engine-cache-server
  _c_dataengine
  _c_mindrecord
  mindspore
  mindspore::protobuf
  mindspore::grpc++
  mindspore_gvar

@@ -85,6 +86,7 @@ if(ENABLE_CACHE)
  engine-cache-server
  _c_dataengine
  _c_mindrecord
  mindspore
  mindspore::protobuf
  mindspore::grpc++
  mindspore_gvar

@@ -103,6 +105,7 @@ if(ENABLE_CACHE)

add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread)
target_link_libraries(cache_admin mindspore mindspore_shared_lib)

if(USE_GLOG)
  target_link_libraries(cache_admin mindspore::glog)

@@ -22,10 +22,11 @@
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/path.h"

namespace ms = mindspore;
namespace ds = mindspore::dataset;

int main(int argc, char **argv) {
  ds::Status rc;
  ms::Status rc;
  ds::CacheAdminArgHandler args;
  std::stringstream arg_stream;

@@ -89,7 +89,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
  ArgValue selected_arg = arg_map_[option];
  if (used_args_[selected_arg]) {
    std::string err_msg = "The " + option + " argument was given more than once.";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  // Flag that this arg is used now

@@ -101,7 +101,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
  if (command_id != CommandId::kCmdUnknown) {
    if (command_id_ != CommandId::kCmdUnknown) {
      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
      return Status(StatusCode::kSyntaxError, err_msg);
      return Status(StatusCode::kMDSyntaxError, err_msg);
    } else {
      command_id_ = command_id;
    }

@@ -113,7 +113,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
  *arg_stream >> value_as_string;
  if (value_as_string.empty()) {
    std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  // Now, attempt to convert the value into it's numeric format for output

@@ -121,7 +121,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
    *out_arg = std::stoul(value_as_string);
  } catch (const std::exception &e) {
    std::string err_msg = "Invalid numeric value: " + value_as_string;
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  return Status::OK();

@@ -133,7 +133,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
  ArgValue selected_arg = arg_map_[option];
  if (used_args_[selected_arg]) {
    std::string err_msg = "The " + option + " argument was given more than once.";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  // Flag that this arg is used now

@@ -145,7 +145,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
  if (command_id != CommandId::kCmdUnknown) {
    if (command_id_ != CommandId::kCmdUnknown) {
      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
      return Status(StatusCode::kSyntaxError, err_msg);
      return Status(StatusCode::kMDSyntaxError, err_msg);
    } else {
      command_id_ = command_id;
    }

@@ -158,12 +158,12 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
    *arg_stream >> *out_arg;
  } else {
    std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  if (out_arg->empty()) {
    std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }
}

@@ -176,7 +176,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
  ArgValue selected_arg = arg_map_[option];
  if (used_args_[selected_arg]) {
    std::string err_msg = "The " + option + " argument was given more than once.";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  // Flag that this arg is used now

@@ -188,7 +188,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
  if (command_id != CommandId::kCmdUnknown) {
    if (command_id_ != CommandId::kCmdUnknown) {
      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
      return Status(StatusCode::kSyntaxError, err_msg);
      return Status(StatusCode::kMDSyntaxError, err_msg);
    } else {
      command_id_ = command_id;
    }

@@ -200,7 +200,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
  *arg_stream >> value_as_string;
  if (value_as_string.empty()) {
    std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  // Now, attempt to convert the value into it's string format for output

@@ -208,7 +208,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::
    *out_arg = std::stof(value_as_string, nullptr);
  } catch (const std::exception &e) {
    std::string err_msg = "Invalid numeric value: " + value_as_string;
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  return Status::OK();

@@ -224,7 +224,7 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
  if (hostname_ != std::string(kCfgDefaultCacheHost)) {
    std::string err_msg =
      "Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used.";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }
  break;
}

@@ -304,7 +304,7 @@ Status CacheAdminArgHandler::Validate() {
  if (!trailing_args_.empty()) {
    std::string err_msg = "Invalid arguments provided: " + trailing_args_;
    err_msg += "\nPlease try `cache_admin --help` for more information";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  // The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to

@@ -312,18 +312,18 @@ Status CacheAdminArgHandler::Validate() {
  if (command_id_ == CommandId::kCmdUnknown) {
    std::string err_msg = "No command provided";
    err_msg += "\nPlease try `cache_admin --help` for more information";
    return Status(StatusCode::kSyntaxError, err_msg);
    return Status(StatusCode::kMDSyntaxError, err_msg);
  }

  // Additional checks here
  auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100);
  if (num_workers_ < 1 || num_workers_ > max_num_workers)
    return Status(StatusCode::kSyntaxError,
    return Status(StatusCode::kMDSyntaxError,
                  "Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + ".");
  if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
  if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kMDSyntaxError, "Log level must be in range (0..3).");
  if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
    return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1");
  if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535).");
    return Status(StatusCode::kMDSyntaxError, "Memory cap ratio should be positive and no greater than 1");
  if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kMDSyntaxError, "Port must be in range (1025..65535).");

  return Status::OK();
}

@@ -467,9 +467,9 @@ Status CacheAdminArgHandler::StopServer(CommandId command_id) {
  Status rc = rq->Wait();
  if (rc.IsError()) {
    msg.RemoveResourcesOnExit();
    if (rc.IsNetWorkError()) {
    if (rc == StatusCode::kMDNetWorkError) {
      std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already.";
      return Status(StatusCode::kNetWorkError, errMsg);
      return Status(StatusCode::kMDNetWorkError, errMsg);
    }
    return rc;
  }

@@ -544,7 +544,7 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
  if (WIFEXITED(status)) {
    auto exit_status = WEXITSTATUS(status);
    if (exit_status) {
      return Status(StatusCode::kUnexpectedError, msg);
      return Status(StatusCode::kMDUnexpectedError, msg);
    } else {
      // Not an error, some info message goes to stdout
      std::cout << msg << std::endl;
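Editor's note: the remaining hunks are the mechanical half of the refactor: every dataset StatusCode gains a kMD prefix and moves up to namespace mindspore, rc.get_code() becomes rc.StatusCode(), and predicate helpers such as IsOutofMemory()/IsNetWorkError() give way to direct comparisons against the enum. The pattern in miniature (treating the include path as the one the diffed headers themselves use):

```cpp
#include "minddata/dataset/include/status.h"  // Status/StatusCode live in namespace mindspore after this commit

mindspore::Status HandleSketch(const mindspore::Status &rc) {
  // Old style:  if (rc.IsNetWorkError()) ...   and   rc.get_code()
  // New style:  compare the Status (or its code) to the prefixed enum.
  if (rc == mindspore::StatusCode::kMDNetWorkError) {
    return mindspore::Status(mindspore::StatusCode::kMDNetWorkError, "server not up or already shut down");
  }
  // Some codes stay benign signals rather than errors, e.g. kMDDuplicateKey
  // in the CreateCache flow diffed below.
  return rc.StatusCode() == mindspore::StatusCode::kMDDuplicateKey ? mindspore::Status::OK() : rc;
}
```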

@@ -75,7 +75,7 @@ Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, vo
  do {
    std::unique_lock<std::mutex> lock(mux_[slot]);
    rc = shm_pool_[slot]->Allocate(sz, p);
    if (rc.IsOutofMemory()) {
    if (rc == StatusCode::kMDOutOfMemory) {
      slot = (slot + 1) % shm_pool_.size();
    }
  } while (rc.IsError() && slot != begin_slot);

@@ -137,7 +137,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {

Status CacheClient::AsyncWriteRow(const TensorRow &row) {
  if (async_buffer_stream_ == nullptr) {
    return Status(StatusCode::kNotImplementedYet);
    return Status(StatusCode::kMDNotImplementedYet);
  }
  RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
  return Status::OK();

@@ -145,7 +145,7 @@ Status CacheClient::AsyncWriteRow(const TensorRow &row) {

Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
  if (async_buffer_stream_ == nullptr) {
    return Status(StatusCode::kNotImplementedYet);
    return Status(StatusCode::kMDNotImplementedYet);
  } else {
    Status rc;
    std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();

@@ -155,7 +155,7 @@ Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
      TensorRow row;
      RETURN_IF_NOT_OK(in->PopRow(&row));
      rc = AsyncWriteRow(row);
      if (rc.get_code() == StatusCode::kNotImplementedYet) {
      if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
        tensor_table->push_back(row);
      } else if (rc.IsError()) {
        return rc;

@@ -165,7 +165,7 @@ Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
    // If not all of them can be sent async, return what's left back to the caller.
    if (!tensor_table->empty()) {
      in->set_tensor_table(std::move(tensor_table));
      return Status(StatusCode::kNotImplementedYet);
      return Status(StatusCode::kMDNotImplementedYet);
    }
  }
  return Status::OK();

@@ -225,7 +225,8 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
    auto cache_state = static_cast<CacheServiceState>(out);
    if (cache_state == CacheServiceState::kFetchPhase ||
        (cache_state == CacheServiceState::kBuildPhase && cookie_.empty())) {
      return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
      return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
                    "Not an error and we should bypass the build phase");
    }
  } else {
    cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client

@@ -243,10 +244,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
  auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag);
  RETURN_IF_NOT_OK(PushRequest(rq));
  Status rc = rq->Wait();
  bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey);
  bool success = (rc.IsOk() || rc.StatusCode() == StatusCode::kMDDuplicateKey);
  // If we get kDuplicateKey, it just means we aren't the first one to create the cache,
  // and we will continue to parse the result.
  if (rc.get_code() == StatusCode::kDuplicateKey) {
  if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
    RETURN_IF_NOT_OK(rq->PostReply());
  }
  if (success) {

@@ -443,7 +444,7 @@ Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
  }
  // If the size is too big, tell the user to send it directly.
  if (sz > kAsyncBufferSize) {
    return Status(StatusCode::kNotImplementedYet);
    return Status(StatusCode::kMDNotImplementedYet);
  }
  std::unique_lock<std::mutex> lock(mux_);
  // Check error from the server side while we have the lock;
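Editor's note: in these client hunks, kMDNotImplementedYet is not a failure but a backpressure signal: AsyncWriteBuffer puts any rows it could not send back inside the buffer, and the caller is expected to retry them on the synchronous WriteBuffer path. A sketch of that caller-side contract (the wrapper function is invented; the two CacheClient methods and the header path are assumptions drawn from the hunks above):

```cpp
#include <memory>
#include "minddata/dataset/engine/cache/cache_client.h"  // assumed header for CacheClient

mindspore::Status WriteWithFallbackSketch(mindspore::dataset::CacheClient *cc,
                                          std::unique_ptr<mindspore::dataset::DataBuffer> &&in) {
  mindspore::Status rc = cc->AsyncWriteBuffer(std::move(in));
  if (rc.StatusCode() == mindspore::StatusCode::kMDNotImplementedYet) {
    // AsyncWriteBuffer restored the leftover rows into `in` before returning,
    // so the buffer is still valid here and can be sent synchronously.
    return cc->WriteBuffer(std::move(in));
  }
  return rc;
}
```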

@@ -66,7 +66,7 @@ enum class CacheServiceState : int8_t {
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
  reply->set_rc(static_cast<int32_t>(rc.get_code()));
  reply->set_rc(static_cast<int32_t>(rc.StatusCode()));
  reply->set_msg(rc.ToString());
}
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number

@@ -98,7 +98,7 @@ Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffer
    (*out_fbb) = std::move(fbb);
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
    return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
  }
}

@@ -95,7 +95,7 @@ Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
  std::unique_lock<std::mutex> lck(mux_);
  auto r = req_.emplace(seqNo, std::move(tag));
  if (!r.second) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__);
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__);
  }
}
// Last step is to tag the request.

@@ -124,7 +124,7 @@ Status CacheClientGreeter::WorkerEntry() {
  } else {
    err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
  }
  Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg);
  Status remote_rc = Status(StatusCode::kMDNetWorkError, __LINE__, __FILE__, err_msg);
  Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
}
// Notify the waiting thread.

@@ -25,7 +25,7 @@ Status PortToFtok(int port, SharedMemory::shm_key_t *out) {
  shmkey = ftok(unix_path.data(), 'a');
  if (shmkey == (key_t)-1) {
    std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno);
    return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg);
    return Status(errno == ENOENT ? StatusCode::kMDFileNotExist : StatusCode::kMDUnexpectedError, errMsg);
  }
  *out = shmkey;
  return Status::OK();

@@ -56,7 +56,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
  CacheMsgBuf msg{
    1,
  };
  msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
  msg.body.status.err_code = static_cast<int32_t>(rc.StatusCode());
  auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size());
  CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err));
  msg.body.status.err_msg[rc.ToString().size()] = '\0';

@@ -25,16 +25,17 @@
#include <chrono>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset;

/// Start the server
/// \param argv
/// \return Status object
ds::Status StartServer(int argc, char **argv) {
  ds::Status rc;
ms::Status StartServer(int argc, char **argv) {
  ms::Status rc;
  ds::CacheServer::Builder builder;
  if (argc != 8) {
    return ds::Status(ds::StatusCode::kSyntaxError);
    return ms::Status(ms::StatusCode::kMDSyntaxError);
  }

  int32_t port = strtol(argv[3], nullptr, 10);

@@ -53,7 +54,7 @@ ds::Status StartServer(int argc, char **argv) {
  // is called. This is a standard procedure for daemonize a process on unix.
  if (chdir("/") == -1) {
    std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
    return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
    return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  }

  // A message queue for communication between parent and child (if we fork).

@@ -80,13 +81,13 @@ ds::Status StartServer(int argc, char **argv) {
  // failed to fork
  if (pid < 0) {
    std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno);
    return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
    return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else if (pid > 0) {
    // Parent and will be responsible for remove the queue on exit.
    msg.RemoveResourcesOnExit();
    // Sleep one second and we attach to the msg que
    std::this_thread::sleep_for(std::chrono::seconds(1));
    ds::Status child_rc;
    ms::Status child_rc;
    rc = msg.ReceiveStatus(&child_rc);
    if (rc.IsError()) {
      return rc;

@@ -101,7 +102,7 @@ ds::Status StartServer(int argc, char **argv) {
                 "logs (under "
              << ds::DefaultLogDir() << ") for any issues that may happen after startup\n";
    signal(SIGCHLD, SIG_IGN);  // ignore sig child signal.
    return ds::Status::OK();
    return ms::Status::OK();
  } else {
    // Child process will continue from here if daemonize and parent has already exited.
    // If we are running in the foreground, none of the code in block below will be run.

@@ -110,7 +111,7 @@ ds::Status StartServer(int argc, char **argv) {
    sid = setsid();
    if (sid < 0) {
      std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno);
      return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
      return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
    }
    close(0);
    close(1);

@@ -137,10 +138,10 @@ ds::Status StartServer(int argc, char **argv) {

int main(int argc, char **argv) {
  // This executable is not to be called directly, and should be invoked by cache_admin executable.
  ds::Status rc = StartServer(argc, argv);
  ms::Status rc = StartServer(argc, argv);
  // Check result
  if (rc.IsError()) {
    auto errCode = rc.get_code();
    auto errCode = rc.StatusCode();
    auto errMsg = rc.ToString();
    std::cerr << errMsg << std::endl;
    return static_cast<int>(errCode);

@ -136,7 +136,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
inx = (inx + 1) % num_slots;
} else {
return rc;

@ -162,7 +162,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
// Make the next arena and continue.
slot = (slot + 1) % num_segments;
} else {

@ -172,7 +172,7 @@ Status NumaMemoryPool::Allocate(size_t n, void **p) {
}
// Handle the case we have done one round robin search.
if (ptr == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
return rc;
}
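
The replacements here swap predicate helpers such as rc.IsOutofMemory() for direct comparisons like rc == StatusCode::kMDOutOfMemory, which presumes the refactored Status can be compared against a bare StatusCode value. The real definition ships in the new include/api/status.h; the following is only a sketch of the minimal shape these call sites rely on, not the actual header:

  // Assumed shape, for illustration only.
  enum StatusCode : int { kSuccess = 0, kMDOutOfMemory, kMDNoSpace, kMDInterrupted, kMDNetWorkError };

  class Status {
   public:
    Status() : code_(kSuccess) {}
    explicit Status(enum StatusCode c) : code_(c) {}
    enum StatusCode StatusCode() const { return code_; }  // accessor that replaces get_code()
    bool IsOk() const { return code_ == kSuccess; }
    bool IsError() const { return !IsOk(); }

   private:
    enum StatusCode code_;
  };

  // Equality against a code is what lets call sites write rc == StatusCode::kMDOutOfMemory.
  inline bool operator==(const Status &s, enum StatusCode c) { return s.StatusCode() == c; }
  inline bool operator!=(const Status &s, enum StatusCode c) { return !(s == c); }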

@ -108,7 +108,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
bl.ptr = nullptr;
return rc;
}
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
// If no memory, write to disk.
if (sm_ != nullptr) {
MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes.";

@ -116,7 +116,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
} else {
// If asked to spill to disk instead but there is no storage set up, simply return no memory
// instead.
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "No enough storage for cache server to cache data");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "No enough storage for cache server to cache data");
}
} else {
return rc;

@ -125,7 +125,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
try {
rc = tree_->DoInsert(key, bl);
} catch (const std::bad_alloc &e) {
rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
rc = Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
// Duplicate key is treated as error and we will also free the memory.
if (rc.IsError() && bl.ptr != nullptr) {

@ -223,7 +223,7 @@ Status CreateCacheRequest::Prepare() {
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
}

@ -277,7 +277,7 @@ Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
}
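
Both request serializers above follow the same defensive pattern: the flatbuffer is built inside a try block, and an allocation failure surfaces as the (now kMD-prefixed) out-of-memory status instead of an escaping exception. Stripped to a skeleton, assuming the rq_ member and flatbuffers builder seen in the hunks:

  Status CreateCacheRequest::Prepare() {
    try {
      flatbuffers::FlatBufferBuilder fbb;
      // ... serialize the request payload into fbb ...
      rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
      return Status::OK();
    } catch (const std::bad_alloc &e) {
      // std::bad_alloc maps onto the renamed out-of-memory code.
      return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
    }
  }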

@ -169,7 +169,7 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
max_avail -= mem_consumed;
if (max_avail <= 0) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
++it;
}

@ -179,12 +179,12 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
if (max_avail < avail_mem) {
int64_t req_mem = cache_mem_sz * 1048576L;  // It is in MB unit.
if (req_mem > max_avail) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} else if (req_mem == 0) {
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
// 85% of our limit, fail this request.
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
}

@ -249,7 +249,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
client_id = cs->num_clients_.fetch_add(1);
all_caches_.emplace(connection_id, std::move(cs));
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
}

@ -276,7 +276,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
// treat it as OK.
return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK();
return duplicate ? Status(StatusCode::kMDDuplicateKey) : Status::OK();
}

Status CacheServer::DestroyCache(CacheRequest *rq) {

@ -306,7 +306,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto sz = rq->buf_data_size();
std::vector<const void *> buffers;

@ -326,7 +326,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
reply->set_result(std::to_string(id));
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
return Status::OK();

@ -353,7 +353,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
Status rc;
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// Only if the cookie matches, we can accept insert into this cache that has a build phase
if (!cs->HasBuildPhase() || cookie == cs->cookie()) {

@ -365,11 +365,11 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
} else {
auto state = cs->GetState();
if (state != CacheServiceState::kFetchPhase) {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Cache service is not in fetch phase. The current phase is " +
std::to_string(static_cast<int8_t>(state)) + ". Client id: " + std::to_string(client_id));
} else {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Cookie mismatch. Client id: " + std::to_string(client_id));
}
}

@ -413,7 +413,7 @@ Status CacheServer::InternalFetchRow(CacheRequest *rq) {
Status rc;
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data()));
// This is an internal request and is not tied to rpc. But need to post because there

@ -494,7 +494,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
auto &row_id_buf = rq->buf_data(0);

@ -551,7 +551,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
mem.resize(mem_sz);
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
WritableSlice dest(mem.data(), mem_sz);
RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));

@ -568,7 +568,7 @@ Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CacheService::ServiceStat svc_stat;
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));

@ -595,7 +595,7 @@ Status CacheServer::CacheSchema(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
auto &create_schema_buf = rq->buf_data(0);

@ -611,7 +611,7 @@ Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// We are going to use std::string to allocate and hold the result which will be eventually
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose

@ -630,7 +630,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// First piece of data is the cookie
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");

@ -639,7 +639,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
if (cookie == cs->cookie()) {
RETURN_IF_NOT_OK(cs->BuildPhaseDone());
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
return Status::OK();

@ -652,7 +652,7 @@ Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
std::vector<row_id_type> gap;
RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));

@ -680,7 +680,7 @@ Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// First piece of data is the on/off flag
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");

@ -747,7 +747,7 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto client_id = rq->client_id();
MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";

@ -836,7 +836,7 @@ Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *inter
default:
std::string errMsg("Internal error, request type is not row request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}

@ -860,7 +860,7 @@ Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
default:
std::string errMsg("Internal error, request type is not session request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}

@ -931,7 +931,7 @@ Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
default:
std::string errMsg("Internal error, request type is not admin request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}

@ -949,7 +949,7 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
} else {
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}

// Notify it is done, and move on to the next request.

@ -1045,7 +1045,7 @@ Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
RETURN_UNEXPECTED_IF_NULL(q);
auto *p = new (std::nothrow) CacheServerRequest();
if (p == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
}
*q = p;
return Status::OK();

@ -1091,7 +1091,7 @@ Status CacheServer::DestroySession(CacheRequest *rq) {
} else {
std::string errMsg =
"Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + ".";
return Status(StatusCode::kFileNotExist, errMsg);
return Status(StatusCode::kMDFileNotExist, errMsg);
}
}
}

@ -1148,7 +1148,7 @@ Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto state = cs->GetState();
reply->set_result(std::to_string(static_cast<int8_t>(state)));

@ -1247,7 +1247,7 @@ Status CacheServer::Builder::IpcResourceCleanup() {
std::string errMsg = "Cache server is already up and running";
// We return a duplicate error. The main() will intercept
// and output a proper message
return Status(StatusCode::kDuplicateKey, errMsg);
return Status(StatusCode::kMDDuplicateKey, errMsg);
}
return Status::OK();
}

@ -419,7 +419,7 @@ class CacheServer : public Service {
Status GetRc() {
Status rc;
for (auto &cache_rc : rc_lists_) {
if (cache_rc.IsError() && !cache_rc.IsInterrupted() && rc.IsOk()) {
if (cache_rc.IsError() && cache_rc != StatusCode::kMDInterrupted && rc.IsOk()) {
rc = cache_rc;
}
}

@ -42,7 +42,7 @@ Status CacheService::DoServiceStart() {
// Return an error if we use more than recommended memory.
std::string errMsg = "Requesting cache size " + std::to_string(cache_mem_sz_) +
" while available system memory " + std::to_string(avail_mem);
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, errMsg);
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, errMsg);
}
memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem;
}

@ -79,7 +79,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
try {
// The first buffer is a flatbuffer which describes the rest of the buffers follow

@ -119,16 +119,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
}
// Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, all_data);
if (rc == Status(StatusCode::kDuplicateKey)) {
if (rc == Status(StatusCode::kMDDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
if (rc == StatusCode::kMDNoSpace) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
st_ = CacheServiceState::kOutOfMemory;
}
}

@ -152,7 +152,7 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
try {
// If we don't need to generate id, we need to find it from the buffer.

@ -172,16 +172,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
}
// Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, {src});
if (rc == Status(StatusCode::kDuplicateKey)) {
if (rc == Status(StatusCode::kMDDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
if (rc == StatusCode::kMDNoSpace) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
} else if (rc == StatusCode::kMDOutOfMemory) {
st_ = CacheServiceState::kOutOfMemory;
}
}

@ -307,7 +307,7 @@ Status CacheService::FetchSchema(std::string *out) const {
if (!mem.empty()) {
*out = std::move(mem);
} else {
return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
return Status(StatusCode::kMDFileNotExist, __LINE__, __FILE__, "No schema has been cached");
}
return Status::OK();
}

@ -36,7 +36,7 @@ Status CachePerfMsg::Receive(int32_t qID) {
auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR);
if (err == -1) {
if (errno == EIDRM) {
return Status(StatusCode::kInterrupted);
return Status(StatusCode::kMDInterrupted);
} else {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);

@ -33,7 +33,7 @@ int main(int argc, char **argv) {
if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl;
}
return static_cast<int>(rc.get_code());
return static_cast<int>(rc.StatusCode());
}
return 0;
}

@ -100,5 +100,7 @@ class CachePerfRun {
};
}  // namespace dataset
}  // namespace mindspore

// todo: waiting for the master of the codes to refactor
#define get_code StatusCode
#define kDuplicateKey kMDDuplicateKey
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_

@ -33,12 +33,12 @@ int main(int argc, char **argv) {
// If we hit any error, send the rc back to the parent.
if (rc.IsError()) {
ds::ErrorMsg proto;
proto.set_rc(static_cast<int32_t>(rc.get_code()));
proto.set_rc(static_cast<int32_t>(rc.StatusCode()));
proto.set_msg(rc.ToString());
ds::CachePerfMsg msg;
(void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto);
}
return static_cast<int>(rc.get_code());
return static_cast<int>(rc.StatusCode());
}
return 0;
}

@ -115,5 +115,9 @@ class CachePipelineRun {
};
}  // namespace dataset
}  // namespace mindspore

// todo: waiting for the master of the codes to refactor
#define get_code StatusCode
#define kDuplicateKey kMDDuplicateKey
#define IsOutofMemory() StatusCode() == StatusCode::kMDOutOfMemory
#define IsNoSpace() StatusCode() == StatusCode::kMDNoSpace
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_
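
As the todo comments admit, both perf tools postpone their own migration behind macro shims: get_code is token-rewritten to StatusCode, kDuplicateKey to kMDDuplicateKey, and the old IsOutofMemory()/IsNoSpace() predicates expand to comparisons against the new codes. With the defines above in effect, legacy call sites keep compiling unchanged:

  Status rc = cachePipelineRun.Run();   // Run() is illustrative, not from this diff
  auto code = rc.get_code();            // preprocesses to rc.StatusCode()
  if (rc.IsOutofMemory()) {             // preprocesses to rc.StatusCode() == StatusCode::kMDOutOfMemory
    // handle out-of-memory
  }

The cost is that these are plain preprocessor token replacements, so any other use of a name like get_code in the translation unit is rewritten as well, which is presumably why the shim carries a todo to remove it once the owning code is properly refactored.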

@ -104,7 +104,7 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const
if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno;
if (errno == ENOSPC) {
return Status(StatusCode::kNoSpace, __LINE__, __FILE__);
return Status(StatusCode::kMDNoSpace, __LINE__, __FILE__);
} else {
RETURN_STATUS_UNEXPECTED(strerror(err));
}

@ -157,7 +157,7 @@ Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer
Status rc;
auto sc = new (std::nothrow) StorageContainer(path);
if (sc == nullptr) {
return Status(StatusCode::kOutOfMemory);
return Status(StatusCode::kMDOutOfMemory);
}
rc = sc->Create();
if (rc.IsOk()) {

@ -96,9 +96,9 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu
cont = containers_.at(num_containers - 1);
off64_t offset;
Status rc = cont->Insert(buf, &offset);
if (rc.get_code() == StatusCode::kBuddySpaceFull) {
if (rc.StatusCode() == StatusCode::kMDBuddySpaceFull) {
create_new_container = true;
// Remember how many containers we saw. In the next iteration we will do a comparision to see
// Remember how many containers we saw. In the next iteration we will do a comparison to see
// if someone has already created it.
last_num_container = num_containers;
} else if (rc.IsOk()) {

@ -140,7 +140,7 @@ Status ColDescriptor::MaterializeTensorShape(int32_t num_elements, TensorShape *
// If we already had an unknown dimension, then we cannot have a second unknown dimension.
// We only support the compute of a single unknown dim.
if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Requested shape has more than one unknown dimension!");
}

@ -312,12 +312,12 @@ Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::strin
}
// data type is mandatory field
if (type_str.empty())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"json schema file for column " + col_name + " has invalid or missing column type.");

// rank number is mandatory field
if (rank_value <= -1)
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"json schema file for column " + col_name + " must define a positive rank value.");

// Create the column descriptor for this column from the data we pulled from the json file

@ -425,7 +425,7 @@ Status DataSchema::AddColumn(const ColDescriptor &cd) {
Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
// Check if columns node exists. It is required for building schema from file.
if (js.find("columns") == js.end())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"\"columns\" node is required in the schema json file.");
return Status::OK();
}

@ -434,12 +434,12 @@ Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
// name to column index number.
Status DataSchema::GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map) {
if (out_column_name_map == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map.");
}

for (int32_t i = 0; i < col_descs_.size(); ++i) {
if (col_descs_[i].name().empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Constructing column name map from schema, but found empty column name.");
}
(*out_column_name_map)[col_descs_[i].name()] = i;

@ -290,7 +290,7 @@ Status ChildIterator::Drain() {
RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
}
if (curr_buffer_->eof()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
}
return Status::OK();
}

@ -122,7 +122,8 @@ Status BarrierOp::prepare(TensorQTable *const table) {
clean_up_ = false;
buffer_id_ = 0;
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"BarrierOp prepare phase requires a tensor table.");
}
// fill initial row
TensorRow new_row = {};

@ -150,7 +151,7 @@ Status BarrierOp::prepare(TensorQTable *const table) {
// fillBuffer always expects a new table to fill
Status BarrierOp::fillBuffer(TensorQTable *const table) {
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
}
TensorRow new_row = {};
while (table->size() < static_cast<size_t>(rows_per_buffer_)) {

@ -172,7 +173,7 @@ Status BarrierOp::blockCond() {
{
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
}
// we have condition name, however the flexibility is in python today
try {

@ -180,11 +181,11 @@ Status BarrierOp::blockCond() {
py::object ret_py_obj = condition_function_();
// Process the return value
if (!py::isinstance<py::bool_>(ret_py_obj)) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, condition wait function should return true/false.");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
return Status(StatusCode::kMDPyFuncException, e.what());
}
}
return Status::OK();

@ -61,7 +61,7 @@ Status BatchOp::Builder::SanityCheck() {
err += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
std::to_string(builder_num_workers_) + ".\n"
: "";
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
return err.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}

#ifdef ENABLE_PYTHON

@ -261,7 +261,7 @@ Status BatchOp::MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatc

Status BatchOp::LaunchThreadsAndInitOp() {
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(

@ -338,23 +338,23 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized.");
}
try {
py::object size = batch_size_func_(info);
*batch_size = size.cast<int32_t>();
if (*batch_size <= 0) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, batch size function should return an integer greater than 0.");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
return Status(StatusCode::kMDPyFuncException, e.what());
} catch (const py::cast_error &e) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, batch size function should return an integer greater than 0.");
}
}
return Status(StatusCode::kOK, "Batch size func call succeed.");
return Status(StatusCode::kSuccess, "Batch size func call succeed.");
}

Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {

@ -362,7 +362,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized.");
}
try {
// Prepare batch map call back parameters

@ -407,9 +407,9 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
output->push_back(std::move(output_batch));
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
return Status(StatusCode::kMDPyFuncException, e.what());
} catch (const py::cast_error &e) {
return Status(StatusCode::kPyFuncException,
return Status(StatusCode::kMDPyFuncException,
"Invalid parameter, batch map function should return a tuple of list of numpy array.");
}
}

@ -191,7 +191,7 @@ Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t ba
if (bucket_index + 1 >= bucket_boundaries_.size()) {
std::string error_message =
"Invalid data, requested to pad to bucket boundary, element falls in last bucket.";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message);
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, error_message);
}

pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1;

@ -42,7 +42,7 @@ BuildSentencePieceVocabOp::BuildSentencePieceVocabOp(std::shared_ptr<SentencePie

Status BuildSentencePieceVocabOp::operator()() {
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(sentence_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(

@ -84,10 +84,10 @@ Status BuildSentencePieceVocabOp::SentenceThread() {
sentencepiece::util::Status s_status =
sentencepiece::SentencePieceTrainer::Train(BuildParams(), sentence_iter.get(), &model_proto);
if (!s_status.ok()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, s_status.message());
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, s_status.message());
} else {
if (vocab_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, sentencepiece vocab not set.");
}
vocab_->set_model_proto(model_proto);

@ -145,7 +145,7 @@ void BuildSentencePieceVocabOp::Next(std::string *sentence) {

if (new_row[col_id_]->type().IsNumeric() || new_row[col_id_]->Rank() > 1) {
ret_status_ =
Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid data, build_sentence_piece_vocab only works on string data with rank equal to 1, got type: " +
new_row[col_id_]->type().ToString() + "and rank: " + std::to_string(new_row[col_id_]->Rank()));
read_done_ = true;

@ -80,7 +80,7 @@ Status BuildVocabOp::WorkerEntry(int32_t worker_id) {
Status BuildVocabOp::operator()() {
// launch the collector thread
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks()));

@ -233,7 +233,7 @@ Status CacheBase::UpdateColumnMapFromCache() {
// Get the schema from the server. It may not be there yet. So tolerate the error.
if (column_name_id_map_.empty()) {
rc = cache_client_->FetchSchema(&column_name_id_map_);
if (rc == Status(StatusCode::kFileNotExist)) {
if (rc == Status(StatusCode::kMDFileNotExist)) {
MS_LOG(DEBUG) << "Schema not in the server yet.";
rc = Status::OK();
}

@ -304,14 +304,14 @@ Status CacheBase::Prefetcher(int32_t worker_id) {
int32_t retry_count = 0;
do {
rc = PrefetchRows(prefetch_keys, &cache_miss);
if (rc.IsNetWorkError() && retry_count < max_retries) {
if (rc == StatusCode::kMDNetWorkError && retry_count < max_retries) {
// If we get some network error, we will attempt some retries
retry_count++;
} else if (rc.IsError()) {
MS_LOG(WARNING) << rc.ToString();
return rc;
}
} while (rc.IsNetWorkError());
} while (rc == StatusCode::kMDNetWorkError);
// In case any thread is waiting for the rows to come back and blocked on a semaphore,
// we will put an empty row in the local cache.
if (rc.IsError() && AllowCacheMiss()) {

@ -39,12 +39,12 @@ CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_
// Check if the required parameters are set by the builder.
Status CacheLookupOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheLookupOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheLookupOp requires a session id which is not equal to 0.");
}
return Status::OK();

@ -59,7 +59,7 @@ Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
}
Status CacheLookupOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheLookupOp requires a sampler before it can be executed, but got nullptr.");
}
RETURN_IF_NOT_OK(RegisterResources());

@ -129,7 +129,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
Status rc;
if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) {
cache_missing_rows_ = false;
if (rc.IsOutofMemory() || rc.IsNoSpace()) {
if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
cache_client_->ServerRunningOutOfResources();
} else {
MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString();

@ -156,7 +156,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
} else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
cache_missing_rows_ = false;
cache_client_->ServerRunningOutOfResources();
}

@ -188,9 +188,9 @@ Status CacheMergeOp::Cleaner() {
Status rc = rq->CheckCacheResult();
if (rc.IsError()) {
// If interrupt, time to quit.
if (rc.IsInterrupted()) {
if (rc == StatusCode::kMDInterrupted) {
return Status::OK();
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
} else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
// The server is hitting some limit and we will turn off caching from now on.
cache_missing_rows_ = false;
cache_client_->ServerRunningOutOfResources();

@ -215,7 +215,7 @@ Status CacheMergeOp::PrepareNodePostAction() {  // Run any common code from supe
// Construct the cache
const bool generate_ids = false;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
// We are told the cache has been created already.
MS_LOG(INFO) << "Cache created already";
rc = Status::OK();

@ -244,12 +244,12 @@ CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(
// Check if the required parameters are set by the builder.
Status CacheMergeOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheMergeOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheMergeOp requires a session id which is not equal to 0.");
}
return Status::OK();

@ -316,7 +316,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
rc = cc->AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) {
if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {

@ -41,12 +41,12 @@ CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullp
// Check if the required parameters are set by the builder.
Status CacheOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheOp requires a session id which is not equal to 0.");
}
return Status::OK();

@ -78,7 +78,7 @@ Status CacheOp::InitCache() { return Status::OK(); }
// This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheOp requires a sampler before it can be executed, but got nullptr.");
}
RETURN_IF_NOT_OK(RegisterResources());

@ -113,7 +113,7 @@ Status CacheOp::CacheAllRows(int32_t worker_id) {
Status rc;
// Do the Async write if we attach to the shared memory.
rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr));
if (rc.get_code() == StatusCode::kNotImplementedYet) {
if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
} else if (rc.IsError()) {
return rc;

@ -169,9 +169,9 @@ Status CacheOp::WaitForCachingAllRows() {
BuildPhaseDone = true;
break;
case CacheServiceState::kOutOfMemory:
return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory");
return Status(StatusCode::kMDOutOfMemory, "Cache server is running out of memory");
case CacheServiceState::kNoSpace:
return Status(StatusCode::kNoSpace, "Cache server is running of out spill storage");
return Status(StatusCode::kMDNoSpace, "Cache server is running of out spill storage");
case CacheServiceState::kNone:
case CacheServiceState::kError:
default:

@ -246,7 +246,7 @@ Status CacheOp::PrepareNodePostAction() {
// Construct the cache
const bool generate_ids = true;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
// We are told the cache has been created already. So we skip the build phase.
phase_ = Phase::kFetchPhase;
rc = Status::OK();

@ -157,18 +157,14 @@ Status DeviceQueueOp::SendDataToAscend() {
TensorRow currRow;
for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
while (stop_send_ && ascend_keep_waiting_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
WaitContinueSignal();
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
if (status == TdtStatus::FAILED) {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
return Status(StatusCode::kMDTDTPushFailure, "TDT Push Failed");
}
if (create_data_info_queue_) {
DATA_INFO data_info;

@ -200,9 +196,8 @@ Status DeviceQueueOp::SendDataToAscend() {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
return Status(StatusCode::kMDTDTPushFailure, "TDT Push Failed");
}
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
stop_send_ = true;

@ -219,13 +214,19 @@ Status DeviceQueueOp::SendDataToAscend() {

return Status::OK();
}
void DeviceQueueOp::WaitContinueSignal() const {
while (stop_send_ && ascend_keep_waiting_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
}

#endif

#ifdef ENABLE_TDTQUE
Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
if (!create_data_info_queue_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
}
// This place has a race condition with operator(), so the first one
// arrive here will do the initialize work.

@ -241,7 +242,7 @@ Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
}
#else
Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
}
#endif

@ -301,7 +302,7 @@ Status DeviceQueueOp::PushDataToGPU() {
}
handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, release_function);
if (handle == INVALID_HANDLE) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to open channel for sending data.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Failed to open channel for sending data.");
}
is_open = true;
}

@ -309,14 +310,14 @@ Status DeviceQueueOp::PushDataToGPU() {
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
}
}
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
if (ret) {
if (ret == BlockQueueStatus_T::ERROR_INPUT) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Invalid input data, please check it.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Invalid input data, please check it.");
} else {
if (!stop_send_) {
MS_LOG(DEBUG) << "Retry pushing data...";

@ -438,13 +439,13 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
for (auto &sub_item : *items) {
RETURN_IF_NOT_OK(pool_[worker_id]->Allocate(sub_item.data_len_, &sub_item.data_ptr_));
if (sub_item.data_ptr_ == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
}
const unsigned char *column_data = curr_row[i]->GetBuffer();
if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data,
static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) {
MS_LOG(ERROR) << "memcpy_s failed!";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memcpy_s failed.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "memcpy_s failed.");
}
}

@ -190,6 +190,7 @@ class DeviceQueueOp : public PipelineOp {

private:
#ifdef ENABLE_TDTQUE
void WaitContinueSignal() const;
Status SendDataToAscend();
bool ascend_keep_waiting_;
#endif
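
Beyond the status-code renames, DeviceQueueOp also gets a small structural cleanup: the busy-wait on stop_send_ and ascend_keep_waiting_ that used to sit inline in SendDataToAscend() is hoisted into the WaitContinueSignal() helper (declared in the header hunk above, under ENABLE_TDTQUE only), and the TDT push-failure branch collapses to a single early return. At each push site the change amounts to:

  // Before: the polling loop was written out inline.
  while (stop_send_ && ascend_keep_waiting_) {
    MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
    std::this_thread::sleep_for(std::chrono::microseconds(100));
  }

  // After: one shared helper with the same 100-microsecond polling behavior.
  WaitContinueSignal();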

@ -43,7 +43,7 @@ Status FilterOp::Builder::SanityCheck() {
err += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
std::to_string(builder_num_workers_) + ".\n"
: "";
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
return err.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}

FilterOp::Builder::Builder() {

@ -66,7 +66,7 @@ FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_wor
Status FilterOp::operator()() {
// The operator class just starts off threads by calling the tree_ function.
if (tree_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
}
filter_queues_.Init(num_workers_, oc_queue_size_);
RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));

@ -244,7 +244,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
RETURN_IF_NOT_OK(predicate_func_->Compute(input, &output));
RETURN_IF_NOT_OK(output.at(0)->GetItemAt(out_predicate, {}));

return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
}

// Visitor accept method for NodePass