forked from mindspore-Ecosystem/mindspore
add c api
This commit is contained in:
parent
df25ee8c68
commit
1615edc359
|
@ -223,6 +223,8 @@ if(PLATFORM_ARM64)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
__install_micro_wrapper()
|
||||
if(MSLITE_ENABLE_TOOLS)
|
||||
install(TARGETS ${BENCHMARK_NAME} RUNTIME DESTINATION ${BENCHMARK_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -298,6 +300,8 @@ elseif(PLATFORM_ARM32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
__install_micro_wrapper()
|
||||
if(MSLITE_ENABLE_TOOLS AND NOT TARGET_OHOS_LITE)
|
||||
install(TARGETS ${BENCHMARK_NAME} RUNTIME DESTINATION ${BENCHMARK_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -363,6 +367,8 @@ elseif(WIN32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
if(MSVC)
|
||||
install(FILES ${TOP_DIR}/build/mindspore/src/${MINDSPORE_LITE_LIB_NAME}.lib DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -405,6 +411,8 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
|
||||
|
|
|
@ -41,6 +41,7 @@ class DeviceInfoContext;
|
|||
/// \brief Context is used to store environment variables during execution.
|
||||
class MS_API Context {
|
||||
public:
|
||||
struct Data;
|
||||
Context();
|
||||
~Context() = default;
|
||||
|
||||
|
@ -104,7 +105,6 @@ class MS_API Context {
|
|||
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,179 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_C_API_CONTEXT_C_H
|
||||
#define MINDSPORE_INCLUDE_C_API_CONTEXT_C_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
#include "include/c_api/types_c.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef void *MSContextHandle;
|
||||
typedef void *MSDeviceInfoHandle;
|
||||
|
||||
/// \brief Create a context object.
|
||||
///
|
||||
/// \return Context object handle.
|
||||
MS_API MSContextHandle MSContextCreate();
|
||||
|
||||
/// \brief Destroy the context object.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
MS_API void MSContextDestroy(MSContextHandle context);
|
||||
|
||||
/// \brief Set the number of threads at runtime.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
/// \param[in] thread_num the number of threads at runtime.
|
||||
MS_API void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num);
|
||||
|
||||
/// \brief Obtain the current thread number setting.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
///
|
||||
/// \return The current thread number setting.
|
||||
MS_API int32_t MSContextGetThreadNum(const MSContextHandle context);
|
||||
|
||||
/// \brief Set the thread affinity to CPU cores.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
/// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first
|
||||
MS_API void MSContextSetThreadAffinityMode(MSContextHandle context, int mode);
|
||||
|
||||
/// \brief Obtain the thread affinity of CPU cores.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
///
|
||||
/// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
|
||||
MS_API int MSContextGetThreadAffinityMode(const MSContextHandle context);
|
||||
|
||||
/// \brief Set the thread lists to CPU cores.
|
||||
///
|
||||
/// \note If core_list and mode are set by MSContextSetThreadAffinityMode at the same time,
|
||||
/// the core_list is effective, but the mode is not effective.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
/// \param[in] core_list: a array of thread core lists.
|
||||
/// \param[in] core_num The number of core.
|
||||
MS_API void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t *core_list, size_t core_num);
|
||||
|
||||
/// \brief Obtain the thread lists of CPU cores.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
/// \param[out] core_num The number of core.
|
||||
///
|
||||
/// \return a array of thread core lists.
|
||||
MS_API const int32_t *MSContextGetThreadAffinityCoreList(const MSContextHandle context, size_t *core_num);
|
||||
|
||||
/// \brief Set the status whether to perform model inference or training in parallel.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
/// \param[in] is_parallel: true, parallel; false, not in parallel.
|
||||
MS_API void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel);
|
||||
|
||||
/// \brief Obtain the status whether to perform model inference or training in parallel.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
///
|
||||
/// \return Bool value that indicates whether in parallel.
|
||||
MS_API bool MSContextGetEnableParallel(const MSContextHandle context);
|
||||
|
||||
/// \brief Add device info to context object.
|
||||
///
|
||||
/// \param[in] context Context object handle.
|
||||
/// \param[in] device_info Device info object handle.
|
||||
MS_API void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_info);
|
||||
|
||||
/// \brief Create a device info object.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
///
|
||||
/// \return Device info object handle.
|
||||
MS_API MSDeviceInfoHandle MSDeviceInfoCreate(MSDeviceType device_type);
|
||||
|
||||
/// \brief Destroy the device info object.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
MS_API void MSDeviceInfoDestroy(MSDeviceInfoHandle device_info);
|
||||
|
||||
/// \brief Set provider's name.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
/// \param[in] provider define the provider's name.
|
||||
MS_API void MSDeviceInfoSetProvider(MSDeviceInfoHandle device_info, const char *provider);
|
||||
|
||||
/// \brief Obtain provider's name
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
///
|
||||
/// \return provider's name.
|
||||
MS_API const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info);
|
||||
|
||||
/// \brief Set provider's device type.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
/// \param[in] device define the provider's device type. EG: CPU.
|
||||
MS_API void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *device);
|
||||
|
||||
/// \brief Obtain provider's device type.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
///
|
||||
/// \return provider's device type.
|
||||
MS_API const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info);
|
||||
|
||||
/// \brief Obtain the device type of the device info.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
///
|
||||
/// \return Device Type of the device info.
|
||||
MS_API MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info);
|
||||
|
||||
/// \brief Set enables to perform the float16 inference, Only valid for CPU/GPU.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
/// \param[in] is_fp16 Enable float16 inference or not.
|
||||
MS_API void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16);
|
||||
|
||||
/// \brief Obtain enables to perform the float16 inference, Only valid for CPU/GPU.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
///
|
||||
/// \return Whether enable float16 inference.
|
||||
MS_API bool MSDeviceInfoGetEnableFP16(const MSDeviceInfoHandle device_info);
|
||||
|
||||
/// \brief Set the NPU frequency, Only valid for NPU.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
/// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
|
||||
/// performance), default as 3.
|
||||
MS_API void MSDeviceInfoSetFrequency(MSDeviceInfoHandle device_info, int frequency);
|
||||
|
||||
/// \brief Obtain the NPU frequency, Only valid for NPU.
|
||||
///
|
||||
/// \param[in] device_info Device info object handle.
|
||||
///
|
||||
/// \return NPU frequency
|
||||
MS_API int MSDeviceInfoGetFrequency(const MSDeviceInfoHandle device_info);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_INCLUDE_C_API_CONTEXT_C_H
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_C_API_DATA_TYPE_C_H
|
||||
#define MINDSPORE_INCLUDE_C_API_DATA_TYPE_C_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef enum MSDataType {
|
||||
kMSDataTypeUnknown = 0,
|
||||
kMSDataTypeObjectTypeString = 12,
|
||||
kMSDataTypeObjectTypeList = 13,
|
||||
kMSDataTypeObjectTypeTuple = 14,
|
||||
kMSDataTypeObjectTypeTensor = 17,
|
||||
kMSDataTypeNumberTypeBegin = 29,
|
||||
kMSDataTypeNumberTypeBool = 30,
|
||||
kMSDataTypeNumberTypeInt8 = 32,
|
||||
kMSDataTypeNumberTypeInt16 = 33,
|
||||
kMSDataTypeNumberTypeInt32 = 34,
|
||||
kMSDataTypeNumberTypeInt64 = 35,
|
||||
kMSDataTypeNumberTypeUInt8 = 37,
|
||||
kMSDataTypeNumberTypeUInt16 = 38,
|
||||
kMSDataTypeNumberTypeUInt32 = 39,
|
||||
kMSDataTypeNumberTypeUInt64 = 40,
|
||||
kMSDataTypeNumberTypeFloat16 = 42,
|
||||
kMSDataTypeNumberTypeFloat32 = 43,
|
||||
kMSDataTypeNumberTypeFloat64 = 44,
|
||||
kMSDataTypeNumberTypeEnd = 46,
|
||||
// add new enum here
|
||||
kMSDataTypeInvalid = INT32_MAX,
|
||||
} MSDataType;
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_INCLUDE_C_API_DATA_TYPE_C_H
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_C_API_FORMAT_C_H
|
||||
#define MINDSPORE_INCLUDE_C_API_FORMAT_C_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef enum MSFormat {
|
||||
kMSFormatNCHW = 0,
|
||||
kMSFormatNHWC = 1,
|
||||
kMSFormatNHWC4 = 2,
|
||||
kMSFormatHWKC = 3,
|
||||
kMSFormatHWCK = 4,
|
||||
kMSFormatKCHW = 5,
|
||||
kMSFormatCKHW = 6,
|
||||
kMSFormatKHWC = 7,
|
||||
kMSFormatCHWK = 8,
|
||||
kMSFormatHW = 9,
|
||||
kMSFormatHW4 = 10,
|
||||
kMSFormatNC = 11,
|
||||
kMSFormatNC4 = 12,
|
||||
kMSFormatNC4HW4 = 13,
|
||||
kMSFormatNCDHW = 15,
|
||||
kMSFormatNWC = 16,
|
||||
kMSFormatNCW = 17
|
||||
} MSFormat;
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_INCLUDE_C_API_FORMAT_C_H
|
|
@ -0,0 +1,137 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_C_API_MODEL_C_H
|
||||
#define MINDSPORE_INCLUDE_C_API_MODEL_C_H
|
||||
|
||||
#include "include/c_api/tensor_c.h"
|
||||
#include "include/c_api/context_c.h"
|
||||
#include "include/c_api/status_c.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef void *MSModelHandle;
|
||||
|
||||
typedef struct MSTensorHandleArray {
|
||||
size_t handle_num;
|
||||
MSTensorHandle *handle_list;
|
||||
} MSTensorHandleArray;
|
||||
|
||||
#define MS_MAX_SHAPE_NUM 32
|
||||
typedef struct MSShapeInfo {
|
||||
size_t shape_num;
|
||||
int64_t shape[MS_MAX_SHAPE_NUM];
|
||||
} MSShapeInfo;
|
||||
|
||||
typedef struct MSCallBackParamC {
|
||||
char *node_name;
|
||||
char *node_type;
|
||||
} MSCallBackParamC;
|
||||
|
||||
typedef bool (*MSKernelCallBackC)(const MSTensorHandleArray inputs, const MSTensorHandleArray outputs,
|
||||
const MSCallBackParamC kernel_Info);
|
||||
|
||||
/// \brief Create a model object. Only valid for Lite.
|
||||
///
|
||||
/// \return Model object handle.
|
||||
MS_API MSModelHandle MSModelCreate();
|
||||
|
||||
/// \brief Destroy the model object. Only valid for Lite.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
MS_API void MSModelDestroy(MSModelHandle model);
|
||||
|
||||
/// \brief Build the model from model file buffer so that it can run on a device. Only valid for Lite.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
/// \param[in] model_data Define the buffer read from a model file.
|
||||
/// \param[in] data_size Define bytes number of model file buffer.
|
||||
/// \param[in] model_type Define The type of model file.
|
||||
/// \param[in] model_context Define the context used to store options during execution.
|
||||
///
|
||||
/// \return MSStatus.
|
||||
MS_API MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_size, MSModelType model_type,
|
||||
const MSContextHandle model_context);
|
||||
|
||||
/// \brief Load and build the model from model path so that it can run on a device. Only valid for Lite.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
/// \param[in] model_path Define the model file path.
|
||||
/// \param[in] model_type Define The type of model file.
|
||||
/// \param[in] model_context Define the context used to store options during execution.
|
||||
///
|
||||
/// \return MSStatus.
|
||||
MS_API MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSModelType model_type,
|
||||
const MSContextHandle model_context);
|
||||
|
||||
/// \brief Resizes the shapes of inputs.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
/// \param[in] inputs The array that includes all input tensor handles.
|
||||
/// \param[in] shape_infos Defines the new shapes of inputs, should be consistent with inputs.
|
||||
/// \param[in] shape_info_num The num of shape_infos.
|
||||
///
|
||||
/// \return MSStatus.
|
||||
MS_API MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inputs, MSShapeInfo *shape_infos,
|
||||
size_t shape_info_num);
|
||||
|
||||
/// \brief Inference model.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
/// \param[in] inputs The array that includes all input tensor handles.
|
||||
/// \param[out] outputs The array that includes all output tensor handles.
|
||||
/// \param[in] before CallBack before predict.
|
||||
/// \param[in] after CallBack after predict.
|
||||
///
|
||||
/// \return MSStatus.
|
||||
MS_API MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, MSTensorHandleArray *outputs,
|
||||
const MSKernelCallBackC before, const MSKernelCallBackC after);
|
||||
|
||||
/// \brief Obtains all input tensor handles of the model.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
///
|
||||
/// \return The array that includes all input tensor handles.
|
||||
MS_API MSTensorHandleArray MSModelGetInputs(const MSModelHandle model);
|
||||
|
||||
/// \brief Obtains all output tensor handles of the model.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
///
|
||||
/// \return The array that includes all output tensor handles.
|
||||
MS_API MSTensorHandleArray MSModelGetOutputs(const MSModelHandle model);
|
||||
|
||||
/// \brief Obtains the input tensor handle of the model by name.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
/// \param[in] tensor_name The name of tensor.
|
||||
///
|
||||
/// \return The input tensor handle with the given name, if the name is not found, an NULL is returned.
|
||||
MS_API MSTensorHandle MSModelGetInputByTensorName(const MSModelHandle model, const char *tensor_name);
|
||||
|
||||
/// \brief Obtains the output tensor handle of the model by name.
|
||||
///
|
||||
/// \param[in] model Model object handle.
|
||||
/// \param[in] tensor_name The name of tensor.
|
||||
///
|
||||
/// \return The output tensor handle with the given name, if the name is not found, an NULL is returned.
|
||||
MS_API MSTensorHandle MSModelGetOutputByTensorName(const MSModelHandle model, const char *tensor_name);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_INCLUDE_C_API_MODEL_C_H
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_C_API_STATUS_C_H
|
||||
#define MINDSPORE_INCLUDE_C_API_STATUS_C_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
enum MSCompCode {
|
||||
kMSCompCodeCore = 0x00000000u,
|
||||
kMSCompCodeMD = 0x10000000u,
|
||||
kMSCompCodeME = 0x20000000u,
|
||||
kMSCompCodeMC = 0x30000000u,
|
||||
kMSCompCodeLite = 0xF0000000u,
|
||||
};
|
||||
|
||||
typedef enum MSStatus {
|
||||
kMSStatusSuccess = 0,
|
||||
// Core
|
||||
kMSStatusCoreFailed = kMSCompCodeCore | 0x1,
|
||||
|
||||
// Lite // Common error code, range: [-1, -100)
|
||||
kMSStatusLiteError = kMSCompCodeLite | (0x0FFFFFFF & -1), /**< Common error code. */
|
||||
kMSStatusLiteNullptr = kMSCompCodeLite | (0x0FFFFFFF & -2), /**< NULL pointer returned.*/
|
||||
kMSStatusLiteParamInvalid = kMSCompCodeLite | (0x0FFFFFFF & -3), /**< Invalid parameter.*/
|
||||
kMSStatusLiteNoChange = kMSCompCodeLite | (0x0FFFFFFF & -4), /**< No change. */
|
||||
kMSStatusLiteSuccessExit = kMSCompCodeLite | (0x0FFFFFFF & -5), /**< No error but exit. */
|
||||
kMSStatusLiteMemoryFailed = kMSCompCodeLite | (0x0FFFFFFF & -6), /**< Fail to create memory. */
|
||||
kMSStatusLiteNotSupport = kMSCompCodeLite | (0x0FFFFFFF & -7), /**< Fail to support. */
|
||||
kMSStatusLiteThreadPoolError = kMSCompCodeLite | (0x0FFFFFFF & -8), /**< Error occur in thread pool. */
|
||||
kMSStatusLiteUninitializedObj = kMSCompCodeLite | (0x0FFFFFFF & -9), /**< Object is not initialized. */
|
||||
|
||||
// Executor error code, range: [-100,-200)
|
||||
kMSStatusLiteOutOfTensorRange = kMSCompCodeLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
|
||||
kMSStatusLiteInputTensorError = kMSCompCodeLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
|
||||
kMSStatusLiteReentrantError = kMSCompCodeLite | (0x0FFFFFFF & -102), /**< Exist executor running. */
|
||||
|
||||
// Graph error code, range: [-200,-300)
|
||||
kMSStatusLiteGraphFileError = kMSCompCodeLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */
|
||||
|
||||
// Node error code, range: [-300,-400)
|
||||
kMSStatusLiteNotFindOp = kMSCompCodeLite | (0x0FFFFFFF & -300), /**< Failed to find operator. */
|
||||
kMSStatusLiteInvalidOpName = kMSCompCodeLite | (0x0FFFFFFF & -301), /**< Invalid operator name. */
|
||||
kMSStatusLiteInvalidOpAttr = kMSCompCodeLite | (0x0FFFFFFF & -302), /**< Invalid operator attr. */
|
||||
kMSStatusLiteOpExecuteFailure = kMSCompCodeLite | (0x0FFFFFFF & -303), /**< Failed to execution operator. */
|
||||
|
||||
// Tensor error code, range: [-400,-500)
|
||||
kMSStatusLiteFormatError = kMSCompCodeLite | (0x0FFFFFFF & -400), /**< Failed to checking tensor format. */
|
||||
|
||||
// InferShape error code, range: [-500,-600)
|
||||
kMSStatusLiteInferError = kMSCompCodeLite | (0x0FFFFFFF & -500), /**< Failed to infer shape. */
|
||||
kMSStatusLiteInferInvalid = kMSCompCodeLite | (0x0FFFFFFF & -501), /**< Invalid infer shape before runtime. */
|
||||
|
||||
// User input param error code, range: [-600, 700)
|
||||
kMSStatusLiteInputParamInvalid = kMSCompCodeLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
|
||||
} MSStatus;
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_INCLUDE_C_API_STATUS_C_H
|
|
@ -0,0 +1,146 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_C_API_TENSOE_C_H
|
||||
#define MINDSPORE_INCLUDE_C_API_TENSOE_C_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include "include/c_api/types_c.h"
|
||||
#include "include/c_api/data_type_c.h"
|
||||
#include "include/c_api/format_c.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef void *MSTensorHandle;
|
||||
|
||||
/// \brief Create a tensor object.
|
||||
///
|
||||
/// \param[in] name The name of the tensor.
|
||||
/// \param[in] type The data type of the tensor.
|
||||
/// \param[in] shape The shape of the tensor.
|
||||
/// \param[in] shape_num The num of the shape.
|
||||
/// \param[in] data The data pointer that points to allocated memory.
|
||||
/// \param[in] data_len The length of the memory, in bytes.
|
||||
///
|
||||
/// \return Tensor object handle.
|
||||
MS_API MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t *shape, size_t shape_num,
|
||||
const void *data, size_t data_len);
|
||||
|
||||
/// \brief Destroy the tensor object.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
MS_API void MSTensorDestroy(MSTensorHandle tensor);
|
||||
|
||||
/// \brief Obtain a deep copy of the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return Tensor object handle.
|
||||
MS_API MSTensorHandle MSTensorClone(MSTensorHandle tensor);
|
||||
|
||||
/// \brief Set the name for the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
/// \param[in] name The name of the tensor.
|
||||
MS_API void MSTensorSetName(MSTensorHandle tensor, const char *name);
|
||||
|
||||
/// \brief Obtain the name of the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return The name of the tensor.
|
||||
MS_API const char *MSTensorGetName(const MSTensorHandle tensor);
|
||||
|
||||
/// \brief Set the data type for the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
/// \param[in] type The data type of the tensor.
|
||||
MS_API void MSTensorSetDataType(MSTensorHandle tensor, MSDataType type);
|
||||
|
||||
/// \brief Obtain the data type of the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return The date type of the tensor.
|
||||
MS_API MSDataType MSTensorGetDataType(const MSTensorHandle tensor);
|
||||
|
||||
/// \brief Set the shape for the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
/// \param[in] shape The shape array.
|
||||
/// \param[in] shape_num Dimension of shape.
|
||||
MS_API void MSTensorSetShape(MSTensorHandle tensor, const int64_t *shape, size_t shape_num);
|
||||
|
||||
/// \brief Obtain the shape of the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
/// \param[out] shape_num Dimension of shape.
|
||||
///
|
||||
/// \return The shape array of the tensor.
|
||||
MS_API const int64_t *MSTensorGetShape(const MSTensorHandle tensor, size_t *shape_num);
|
||||
|
||||
/// \brief Set the format for the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
/// \param[in] format The format of the tensor.
|
||||
MS_API void MSTensorSetFormat(MSTensorHandle tensor, MSFormat format);
|
||||
|
||||
/// \brief Obtain the format of the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return The format of the tensor.
|
||||
MS_API MSFormat MSTensorGetFormat(const MSTensorHandle tensor);
|
||||
|
||||
/// \brief Obtain the data for the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
/// \param[in] data A pointer to the data of the tensor.
|
||||
MS_API void MSTensorSetData(MSTensorHandle tensor, void *data);
|
||||
|
||||
/// \brief Obtain the data pointer of the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return The data pointer of the tensor.
|
||||
MS_API const void *MSTensorGetData(const MSTensorHandle tensor);
|
||||
|
||||
/// \brief Obtain the mutable data pointer of the tensor. If the internal data is empty, it will allocate memory.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return The data pointer of the tensor.
|
||||
MS_API void *MSTensorGetMutableData(const MSTensorHandle tensor);
|
||||
|
||||
/// \brief Obtain the element number of the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return The element number of the tensor.
|
||||
MS_API int64_t MSTensorGetElementNum(const MSTensorHandle tensor);
|
||||
|
||||
/// \brief Obtain the data size fo the tensor.
|
||||
///
|
||||
/// \param[in] tensor Tensor object handle.
|
||||
///
|
||||
/// \return The data size of the tensor.
|
||||
MS_API size_t MSTensorGetDataSize(const MSTensorHandle tensor);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_INCLUDE_C_API_TENSOE_C_H
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_C_API_TYPES_C_H
|
||||
#define MINDSPORE_INCLUDE_C_API_TYPES_C_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#ifndef MS_API
|
||||
#ifdef _WIN32
|
||||
#define MS_API __declspec(dllexport)
|
||||
#else
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
#endif
|
||||
#endif
|
||||
|
||||
typedef enum MSModelType {
|
||||
kMSModelTypeMindIR = 0,
|
||||
// insert new data type here
|
||||
kMSModelTypeInvalid = 0xFFFFFFFF
|
||||
} MSModelType;
|
||||
|
||||
typedef enum MSDeviceType {
|
||||
kMSDeviceTypeCPU = 0,
|
||||
kMSDeviceTypeGPU,
|
||||
kMSDeviceTypeKirinNPU,
|
||||
// add new type here
|
||||
kMSDeviceTypeInvalid = 100,
|
||||
} MSDeviceType;
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_INCLUDE_C_API_TYPES_C_H
|
|
@ -66,9 +66,14 @@ file(GLOB CXX_API_SRCS
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor/*.cc
|
||||
)
|
||||
|
||||
file(GLOB C_API_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/c_api/*.cc
|
||||
)
|
||||
|
||||
set(API_SRC
|
||||
${CORE_DIR}/utils/status.cc
|
||||
${CXX_API_SRCS}
|
||||
${C_API_SRCS}
|
||||
)
|
||||
|
||||
file(GLOB CXX_API_TRAIN_SRCS
|
||||
|
|
|
@ -0,0 +1,262 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/c_api/context_c.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/delegate.h"
|
||||
#include "include/api/allocator.h"
|
||||
#include "src/cxx_api/context.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
// ================ Context ================
|
||||
MSContextHandle MSContextCreate() {
|
||||
auto impl = new (std::nothrow) mindspore::Context::Data;
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "memory allocation failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<MSContextHandle>(impl);
|
||||
}
|
||||
|
||||
void MSContextDestroy(MSContextHandle context) {
|
||||
if (context != nullptr) {
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
delete impl;
|
||||
}
|
||||
}
|
||||
|
||||
void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
impl->thread_num = thread_num;
|
||||
}
|
||||
|
||||
int32_t MSContextGetThreadNum(const MSContextHandle context) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return 0;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
return impl->thread_num;
|
||||
}
|
||||
|
||||
void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
impl->affinity_mode_ = mode;
|
||||
return;
|
||||
}
|
||||
|
||||
int MSContextGetThreadAffinityMode(const MSContextHandle context) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return 0;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
return impl->affinity_mode_;
|
||||
}
|
||||
|
||||
void MSContextSetThreadAffinityCoreList(MSContextHandle context, const int32_t *core_list, size_t core_num) {
|
||||
if (context == nullptr || core_list == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
impl->affinity_core_list_ = vec_core_list;
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t *MSContextGetThreadAffinityCoreList(const MSContextHandle context, size_t *core_num) {
|
||||
if (context == nullptr || core_num == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
*core_num = impl->affinity_core_list_.size();
|
||||
return impl->affinity_core_list_.data();
|
||||
}
|
||||
|
||||
void MSContextSetEnableParallel(MSContextHandle context, bool is_parallel) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
impl->enable_parallel_ = is_parallel;
|
||||
}
|
||||
|
||||
bool MSContextGetEnableParallel(const MSContextHandle context) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
return impl->enable_parallel_;
|
||||
}
|
||||
|
||||
void MSContextAddDeviceInfo(MSContextHandle context, MSDeviceInfoHandle device_info) {
|
||||
if (context == nullptr || device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::Context::Data *>(context);
|
||||
std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info));
|
||||
impl->device_info_list.push_back(device);
|
||||
}
|
||||
|
||||
// ================ DeviceInfo ================
|
||||
MSDeviceInfoHandle MSDeviceInfoCreate(MSDeviceType device_type) {
|
||||
mindspore::DeviceInfoContext *impl = nullptr;
|
||||
if (device_type == kMSDeviceTypeCPU) {
|
||||
impl = new (std::nothrow) mindspore::CPUDeviceInfo();
|
||||
} else if (device_type == kMSDeviceTypeGPU) {
|
||||
impl = new (std::nothrow) mindspore::GPUDeviceInfo();
|
||||
} else if (device_type == kMSDeviceTypeKirinNPU) {
|
||||
impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type;
|
||||
return nullptr;
|
||||
}
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "memory allocation failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<MSDeviceInfoHandle>(impl);
|
||||
}
|
||||
|
||||
void MSDeviceInfoDestroy(MSDeviceInfoHandle device_info) {
|
||||
if (device_info != nullptr) {
|
||||
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
|
||||
delete impl;
|
||||
}
|
||||
}
|
||||
|
||||
void MSDeviceInfoSetProvider(MSDeviceInfoHandle device_info, const char *provider) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
|
||||
return impl->SetProvider(provider);
|
||||
}
|
||||
|
||||
const char *MSDeviceInfoGetProvider(const MSDeviceInfoHandle device_info) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
|
||||
return impl->GetProvider().c_str();
|
||||
}
|
||||
|
||||
void MSDeviceInfoSetProviderDevice(MSDeviceInfoHandle device_info, const char *device) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
|
||||
return impl->SetProviderDevice(device);
|
||||
}
|
||||
|
||||
const char *MSDeviceInfoGetProviderDevice(const MSDeviceInfoHandle device_info) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
|
||||
return impl->GetProviderDevice().c_str();
|
||||
}
|
||||
|
||||
MSDeviceType MSDeviceInfoGetDeviceType(const MSDeviceInfoHandle device_info) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kMSDeviceTypeInvalid;
|
||||
}
|
||||
auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
|
||||
auto device_type = impl->GetDeviceType();
|
||||
return static_cast<MSDeviceType>(device_type);
|
||||
}
|
||||
|
||||
void MSDeviceInfoSetEnableFP16(MSDeviceInfoHandle device_info, bool is_fp16) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
|
||||
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeCPU) {
|
||||
auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
|
||||
impl->SetEnableFP16(is_fp16);
|
||||
} else if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeGPU) {
|
||||
auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
|
||||
impl->SetEnableFP16(is_fp16);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
}
|
||||
}
|
||||
|
||||
bool MSDeviceInfoGetEnableFP16(const MSDeviceInfoHandle device_info) {
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
|
||||
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeCPU) {
|
||||
auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
|
||||
return impl->GetEnableFP16();
|
||||
} else if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeGPU) {
|
||||
auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
|
||||
return impl->GetEnableFP16();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported Feature. device_type: " << device_type;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void MSDeviceInfoSetFrequency(MSDeviceInfoHandle device_info, int frequency) { // only for KirinNPU
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
|
||||
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeKirinNPU) {
|
||||
auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
|
||||
impl->SetFrequency(frequency);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
}
|
||||
}
|
||||
|
||||
int MSDeviceInfoGetFrequency(const MSDeviceInfoHandle device_info) { // only for KirinNPU
|
||||
if (device_info == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return -1;
|
||||
}
|
||||
auto device_type = static_cast<mindspore::DeviceInfoContext *>(device_info)->GetDeviceType();
|
||||
if (static_cast<MSDeviceType>(device_type) == kMSDeviceTypeKirinNPU) {
|
||||
auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
|
||||
return impl->GetFrequency();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return -1;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,431 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/c_api/model_c.h"
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||
#include "src/cxx_api/converters.h"
|
||||
#include "src/lite_session.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ModelC {
|
||||
public:
|
||||
ModelC() : session_(nullptr), context_(nullptr) {}
|
||||
~ModelC() {
|
||||
for (auto &impl : tensor_map_) {
|
||||
delete impl.second;
|
||||
}
|
||||
}
|
||||
|
||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const mindspore::Context::Data *model_context);
|
||||
Status Build(const std::string &model_path, ModelType model_type, const mindspore::Context::Data *model_context);
|
||||
Status Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes);
|
||||
|
||||
Status Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num,
|
||||
const MSKernelCallBackC &before, const MSKernelCallBackC &after);
|
||||
|
||||
MSTensor::Impl **GetInputs(size_t *input_num);
|
||||
MSTensor::Impl **GetOutputs(size_t *output_num);
|
||||
MSTensor::Impl *GetInputByTensorName(const std::string &name);
|
||||
MSTensor::Impl *GetOutputByTensorName(const std::string &name);
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::LiteSession> session_ = nullptr;
|
||||
std::shared_ptr<const Context::Data> context_ = nullptr;
|
||||
std::map<mindspore::tensor::MSTensor *, MSTensor::Impl *> tensor_map_;
|
||||
std::vector<MSTensor::Impl *> inputs_;
|
||||
std::vector<MSTensor::Impl *> outputs_;
|
||||
Status RunGraph(const MSKernelCallBackC &before, const MSKernelCallBackC &after);
|
||||
void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MSTensor *> tensors);
|
||||
MSTensor::Impl *TensorToTensorImpl(mindspore::tensor::MSTensor *tensor);
|
||||
};
|
||||
|
||||
Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const mindspore::Context::Data *model_context) {
|
||||
context_.reset(model_context);
|
||||
lite::Context lite_context;
|
||||
auto status = A2L_ConvertContext(model_context, &lite_context);
|
||||
if (status != kSuccess) {
|
||||
return status;
|
||||
}
|
||||
session_ = std::shared_ptr<session::LiteSession>(
|
||||
session::LiteSession::CreateSession(static_cast<const char *>(model_data), data_size, &lite_context));
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Allocate session failed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelC::Build(const std::string &model_path, ModelType model_type,
|
||||
const mindspore::Context::Data *model_context) {
|
||||
context_.reset(model_context);
|
||||
lite::Context lite_context;
|
||||
auto status = A2L_ConvertContext(model_context, &lite_context);
|
||||
if (status != kSuccess) {
|
||||
return status;
|
||||
}
|
||||
session_ = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(model_path, &lite_context));
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Allocate session failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelC::Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes) {
|
||||
std::vector<tensor::MSTensor *> inner_input;
|
||||
size_t input_num = inputs.size();
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto input = inputs[i];
|
||||
if (input == nullptr || input->lite_tensor() == nullptr) {
|
||||
MS_LOG(ERROR) << "Input tensor is null.";
|
||||
return kLiteInputTensorError;
|
||||
}
|
||||
inner_input.push_back(input->lite_tensor());
|
||||
}
|
||||
size_t shape_num = shapes.size();
|
||||
std::vector<std::vector<int32_t>> inner_shapes(shape_num);
|
||||
for (size_t i = 0; i < shape_num; i++) {
|
||||
std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(inner_shapes[i]),
|
||||
[](int64_t value) { return static_cast<int32_t>(value); });
|
||||
}
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
auto ret = session_->Resize(inner_input, inner_shapes);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
void ModelC::ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MSTensor *> tensors) {
|
||||
for (size_t j = 0; j < old_data.size(); j++) {
|
||||
tensors.at(j)->set_data(old_data.at(j));
|
||||
}
|
||||
}
|
||||
|
||||
Status ModelC::Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num,
|
||||
const MSKernelCallBackC &before, const MSKernelCallBackC &after) {
|
||||
if (outputs == nullptr || session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kLiteError;
|
||||
}
|
||||
auto model_inputs = session_->GetInputs();
|
||||
if (model_inputs.size() != input_num) {
|
||||
MS_LOG(ERROR) << "Wrong input size.";
|
||||
return kLiteError;
|
||||
}
|
||||
std::vector<void *> old_data;
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto real_input = model_inputs[i];
|
||||
auto user_input = static_cast<mindspore::MSTensor::Impl *>(inputs[i]);
|
||||
if (user_input->DataType() != static_cast<DataType>(real_input->data_type())) {
|
||||
ResetTensorData(old_data, model_inputs);
|
||||
MS_LOG(ERROR) << "DataType does not match, input:" << user_input->Name()
|
||||
<< ", real:" << real_input->tensor_name();
|
||||
return kLiteInputTensorError;
|
||||
}
|
||||
if (user_input->Data() == nullptr) {
|
||||
ResetTensorData(old_data, model_inputs);
|
||||
MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has no data.";
|
||||
return kLiteInputTensorError;
|
||||
}
|
||||
old_data.push_back(real_input->data());
|
||||
if (real_input->data_type() == kObjectTypeString) {
|
||||
std::vector<int32_t> shape;
|
||||
std::transform(user_input->Shape().begin(), user_input->Shape().end(), std::back_inserter(shape),
|
||||
[](int64_t value) { return static_cast<int32_t>(value); });
|
||||
real_input->set_shape(shape);
|
||||
real_input->set_data(user_input->MutableData());
|
||||
} else {
|
||||
if (user_input->MutableData() != real_input->data()) {
|
||||
if (real_input->Size() != user_input->DataSize()) {
|
||||
ResetTensorData(old_data, model_inputs);
|
||||
MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has wrong data size.";
|
||||
return kLiteInputTensorError;
|
||||
}
|
||||
real_input->set_data(user_input->MutableData());
|
||||
}
|
||||
}
|
||||
}
|
||||
auto ret = RunGraph(before, after);
|
||||
ResetTensorData(old_data, model_inputs);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Run graph failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
*outputs = reinterpret_cast<MSTensorHandle *>(GetOutputs(output_num));
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelC::RunGraph(const MSKernelCallBackC &before, const MSKernelCallBackC &after) {
|
||||
if (before == nullptr || after == nullptr) {
|
||||
auto ret = session_->RunGraph();
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
auto before_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
|
||||
const CallBackParam &call_param) {
|
||||
std::vector<mindspore::MSTensor::Impl> inputs_impl;
|
||||
std::vector<mindspore::MSTensor::Impl> outputs_impl;
|
||||
std::vector<MSTensorHandle> op_inputs;
|
||||
std::vector<MSTensorHandle> op_outputs;
|
||||
size_t op_input_num = before_inputs.size();
|
||||
for (size_t i = 0; i < op_input_num; i++) {
|
||||
inputs_impl.emplace_back(before_inputs[i]);
|
||||
op_inputs.push_back(&(inputs_impl.back()));
|
||||
}
|
||||
size_t op_output_num = before_outputs.size();
|
||||
for (size_t i = 0; i < op_output_num; i++) {
|
||||
outputs_impl.emplace_back(before_outputs[i]);
|
||||
op_outputs.push_back(&(outputs_impl.back()));
|
||||
}
|
||||
const MSCallBackParamC op_info = {const_cast<char *>(call_param.node_name.c_str()),
|
||||
const_cast<char *>(call_param.node_type.c_str())};
|
||||
MSTensorHandleArray inputs = {op_input_num, op_inputs.data()};
|
||||
MSTensorHandleArray outputs = {op_output_num, op_outputs.data()};
|
||||
return before(inputs, outputs, op_info);
|
||||
};
|
||||
|
||||
auto after_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
|
||||
const CallBackParam &call_param) {
|
||||
std::vector<mindspore::MSTensor::Impl> inputs_impl;
|
||||
std::vector<mindspore::MSTensor::Impl> outputs_impl;
|
||||
std::vector<MSTensorHandle> op_inputs;
|
||||
std::vector<MSTensorHandle> op_outputs;
|
||||
size_t op_input_num = after_inputs.size();
|
||||
for (size_t i = 0; i < op_input_num; i++) {
|
||||
inputs_impl.emplace_back(after_inputs[i]);
|
||||
op_inputs.push_back(&(inputs_impl.back()));
|
||||
}
|
||||
size_t op_output_num = after_outputs.size();
|
||||
for (size_t i = 0; i < op_output_num; i++) {
|
||||
outputs_impl.emplace_back(after_outputs[i]);
|
||||
op_outputs.push_back(&(outputs_impl.back()));
|
||||
}
|
||||
const MSCallBackParamC op_info = {const_cast<char *>(call_param.node_name.c_str()),
|
||||
const_cast<char *>(call_param.node_type.c_str())};
|
||||
MSTensorHandleArray inputs = {op_input_num, op_inputs.data()};
|
||||
MSTensorHandleArray outputs = {op_output_num, op_outputs.data()};
|
||||
return after(inputs, outputs, op_info);
|
||||
};
|
||||
auto ret = session_->RunGraph(before_call_back, after_call_back);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
MSTensor::Impl *ModelC::TensorToTensorImpl(mindspore::tensor::MSTensor *tensor) {
|
||||
MSTensor::Impl *impl = nullptr;
|
||||
auto iter = tensor_map_.find(tensor);
|
||||
if (iter != tensor_map_.end()) {
|
||||
impl = iter->second;
|
||||
} else {
|
||||
impl = new (std::nothrow) MSTensor::Impl(tensor);
|
||||
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor failed.";
|
||||
return nullptr;
|
||||
}
|
||||
tensor_map_[tensor] = impl;
|
||||
}
|
||||
return impl;
|
||||
}
|
||||
|
||||
MSTensor::Impl **ModelC::GetInputs(size_t *input_num) {
|
||||
if (session_ == nullptr || input_num == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return nullptr;
|
||||
}
|
||||
auto inputs = session_->GetInputs();
|
||||
*input_num = inputs.size();
|
||||
if (inputs_.capacity() < *input_num) {
|
||||
inputs_.reserve(*input_num);
|
||||
}
|
||||
inputs_.clear();
|
||||
std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_),
|
||||
[&](tensor::MSTensor *input) { return TensorToTensorImpl(input); });
|
||||
return inputs_.data();
|
||||
}
|
||||
|
||||
MSTensor::Impl **ModelC::GetOutputs(size_t *output_num) {
|
||||
if (session_ == nullptr || output_num == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return nullptr;
|
||||
}
|
||||
auto outputs = session_->GetOutputs();
|
||||
*output_num = outputs.size();
|
||||
if (outputs_.capacity() < *output_num) {
|
||||
outputs_.reserve(*output_num);
|
||||
}
|
||||
outputs_.clear();
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_),
|
||||
[&](std::unordered_map<std::string, mindspore::tensor::MSTensor *>::value_type iter) {
|
||||
return TensorToTensorImpl(iter.second);
|
||||
});
|
||||
return outputs_.data();
|
||||
}
|
||||
|
||||
MSTensor::Impl *ModelC::GetInputByTensorName(const std::string &name) {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor = session_->GetInputsByTensorName(name);
|
||||
return TensorToTensorImpl(tensor);
|
||||
}
|
||||
|
||||
MSTensor::Impl *ModelC::GetOutputByTensorName(const std::string &name) {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session is null.";
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor = session_->GetOutputByTensorName(name);
|
||||
return TensorToTensorImpl(tensor);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
MSModelHandle MSModelCreate() {
|
||||
auto impl = new (std::nothrow) mindspore::ModelC();
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<MSModelHandle>(impl);
|
||||
}
|
||||
|
||||
void MSModelDestroy(MSModelHandle model) {
|
||||
if (model != nullptr) {
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
delete impl;
|
||||
}
|
||||
}
|
||||
|
||||
MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_size, MSModelType model_type,
|
||||
const MSContextHandle model_context) {
|
||||
if (model == nullptr || model_data == nullptr || model_context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kMSStatusLiteNullptr;
|
||||
}
|
||||
mindspore::Context::Data *context = static_cast<mindspore::Context::Data *>(model_context);
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context);
|
||||
return static_cast<MSStatus>(ret.StatusCode());
|
||||
}
|
||||
|
||||
MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSModelType model_type,
|
||||
const MSContextHandle model_context) {
|
||||
if (model == nullptr || model_path == nullptr || model_context == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kMSStatusLiteNullptr;
|
||||
}
|
||||
mindspore::Context::Data *context = static_cast<mindspore::Context::Data *>(model_context);
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context);
|
||||
return static_cast<MSStatus>(ret.StatusCode());
|
||||
}
|
||||
|
||||
MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inputs, MSShapeInfo *shape_infos,
|
||||
size_t shape_info_num) {
|
||||
if (model == nullptr || shape_infos == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kMSStatusLiteNullptr;
|
||||
}
|
||||
std::vector<mindspore::MSTensor::Impl *> vec_inputs;
|
||||
std::transform(inputs.handle_list, inputs.handle_list + inputs.handle_num, std::back_inserter(vec_inputs),
|
||||
[](MSTensorHandle value) { return static_cast<mindspore::MSTensor::Impl *>(value); });
|
||||
std::vector<std::vector<int64_t>> vec_dims;
|
||||
for (size_t i = 0; i < shape_info_num; i++) {
|
||||
std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num);
|
||||
vec_dims.push_back(shape);
|
||||
}
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
auto ret = impl->Resize(vec_inputs, vec_dims);
|
||||
return static_cast<MSStatus>(ret.StatusCode());
|
||||
}
|
||||
|
||||
MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, MSTensorHandleArray *outputs,
|
||||
const MSKernelCallBackC before, const MSKernelCallBackC after) {
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kMSStatusLiteNullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
auto ret = impl->Predict(inputs.handle_list, inputs.handle_num, &(outputs->handle_list), &(outputs->handle_num),
|
||||
before, after);
|
||||
if (!ret.IsOk()) {
|
||||
MS_LOG(ERROR) << "Predict fail, ret :" << ret;
|
||||
}
|
||||
return static_cast<MSStatus>(ret.StatusCode());
|
||||
}
|
||||
|
||||
MSTensorHandleArray MSModelGetInputs(const MSModelHandle model) {
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return {0, nullptr};
|
||||
}
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
size_t input_num;
|
||||
auto handles = reinterpret_cast<MSTensorHandle *>(impl->GetInputs(&input_num));
|
||||
return {input_num, handles};
|
||||
}
|
||||
|
||||
MSTensorHandleArray MSModelGetOutputs(const MSModelHandle model) {
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return {0, nullptr};
|
||||
}
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
size_t output_num;
|
||||
auto handles = reinterpret_cast<MSTensorHandle *>(impl->GetOutputs(&output_num));
|
||||
return {output_num, handles};
|
||||
}
|
||||
|
||||
MSTensorHandle MSModelGetInputByTensorName(const MSModelHandle model, const char *tensor_name) {
|
||||
if (model == nullptr || tensor_name == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
size_t input_num;
|
||||
auto inputs = impl->GetInputs(&input_num);
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
if (inputs[i]->Name() == tensor_name) {
|
||||
return static_cast<MSTensorHandle>(inputs[i]);
|
||||
}
|
||||
}
|
||||
MS_LOG(ERROR) << "tensor is not exist.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MSTensorHandle MSModelGetOutputByTensorName(const MSModelHandle model, const char *tensor_name) {
|
||||
if (model == nullptr || tensor_name == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::ModelC *>(model);
|
||||
size_t output_num;
|
||||
auto outputs = impl->GetOutputs(&output_num);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
if (outputs[i]->Name() == tensor_name) {
|
||||
return static_cast<MSTensorHandle>(outputs[i]);
|
||||
}
|
||||
}
|
||||
MS_LOG(ERROR) << "tensor is not exist.";
|
||||
return nullptr;
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/c_api/tensor_c.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||
#include "src/runtime/inner_allocator.h"
|
||||
|
||||
MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t *shape, size_t shape_num,
|
||||
const void *data, size_t data_len) {
|
||||
if (name == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<int32_t> vec_shape(shape_num);
|
||||
for (size_t i = 0; i < shape_num; i++) {
|
||||
vec_shape[i] = shape[i];
|
||||
}
|
||||
auto lite_tensor =
|
||||
mindspore::lite::Tensor::CreateTensor(name, static_cast<mindspore::TypeId>(type), vec_shape, data, data_len);
|
||||
auto impl = new (std::nothrow) mindspore::MSTensor::Impl(lite_tensor);
|
||||
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to allocate tensor impl.";
|
||||
return nullptr;
|
||||
}
|
||||
impl->set_from_session(false);
|
||||
return impl;
|
||||
}
|
||||
|
||||
void MSTensorDestroy(MSTensorHandle tensor) {
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
if (impl != nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
}
|
||||
|
||||
MSTensorHandle MSTensorClone(MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
auto lite_tensor = static_cast<mindspore::lite::Tensor *>(impl->lite_tensor());
|
||||
auto clone = mindspore::lite::Tensor::CopyTensor(*lite_tensor, true, lite_tensor->allocator());
|
||||
if (clone == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to allocate tensor.";
|
||||
return nullptr;
|
||||
}
|
||||
auto clone_impl = new (std::nothrow) mindspore::MSTensor::Impl(clone);
|
||||
if (clone_impl == nullptr) {
|
||||
delete clone;
|
||||
MS_LOG(ERROR) << "Failed to allocate tensor impl.";
|
||||
return nullptr;
|
||||
}
|
||||
clone_impl->set_from_session(false);
|
||||
return clone_impl;
|
||||
}
|
||||
|
||||
void MSTensorSetName(MSTensorHandle tensor, const char *name) {
|
||||
if (tensor == nullptr || name == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
impl->SetName(name);
|
||||
}
|
||||
|
||||
const char *MSTensorGetName(const MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return impl->Name().c_str();
|
||||
}
|
||||
|
||||
void MSTensorSetDataType(MSTensorHandle tensor, MSDataType type) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
impl->SetDataType(static_cast<mindspore::DataType>(type));
|
||||
}
|
||||
|
||||
MSDataType MSTensorGetDataType(const MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kMSDataTypeUnknown;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
auto dtype = impl->DataType();
|
||||
return static_cast<MSDataType>(dtype);
|
||||
}
|
||||
|
||||
void MSTensorSetShape(MSTensorHandle tensor, const int64_t *shape, size_t shape_num) {
|
||||
if (tensor == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
std::vector<int64_t> vec_shape(shape_num);
|
||||
for (size_t i = 0; i < shape_num; i++) {
|
||||
vec_shape[i] = shape[i];
|
||||
}
|
||||
impl->SetShape(vec_shape);
|
||||
}
|
||||
|
||||
const int64_t *MSTensorGetShape(const MSTensorHandle tensor, size_t *shape_num) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
*shape_num = impl->Shape().size();
|
||||
return impl->Shape().data();
|
||||
}
|
||||
|
||||
void MSTensorSetFormat(MSTensorHandle tensor, MSFormat format) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return impl->SetFormat(static_cast<mindspore::Format>(format));
|
||||
}
|
||||
|
||||
MSFormat MSTensorGetFormat(const MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return kMSFormatNHWC;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return static_cast<MSFormat>(impl->format());
|
||||
}
|
||||
|
||||
void MSTensorSetData(MSTensorHandle tensor, void *data) {
|
||||
if (tensor == nullptr || data == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return impl->SetData(data);
|
||||
}
|
||||
|
||||
const void *MSTensorGetData(const MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return impl->Data().get();
|
||||
}
|
||||
|
||||
void *MSTensorGetMutableData(const MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return impl->MutableData();
|
||||
}
|
||||
|
||||
int64_t MSTensorGetElementNum(const MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return 0;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return impl->ElementNum();
|
||||
}
|
||||
|
||||
size_t MSTensorGetDataSize(const MSTensorHandle tensor) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr.";
|
||||
return 0;
|
||||
}
|
||||
auto impl = static_cast<mindspore::MSTensor::Impl *>(tensor);
|
||||
return impl->DataSize();
|
||||
}
|
|
@ -13,15 +13,9 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "include/api/context.h"
|
||||
#include "src/cxx_api/context.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#ifndef SUPPORT_NNIE
|
||||
#include <any>
|
||||
#else
|
||||
#include <experimental/any>
|
||||
#endif
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "src/runtime/inner_allocator.h"
|
||||
|
@ -53,24 +47,6 @@ constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.asce
|
|||
constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
|
||||
constexpr auto kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
|
||||
|
||||
struct Context::Data {
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||
int32_t thread_num = 2;
|
||||
bool enable_parallel_ = false;
|
||||
std::vector<int32_t> affinity_core_list_;
|
||||
int affinity_mode_ = 0;
|
||||
std::shared_ptr<Delegate> delegate = nullptr;
|
||||
};
|
||||
|
||||
struct DeviceInfoContext::Data {
|
||||
#ifndef SUPPORT_NNIE
|
||||
std::map<std::string, std::any> params;
|
||||
#else
|
||||
std::map<std::string, std::experimental::any> params;
|
||||
#endif
|
||||
std::shared_ptr<Allocator> allocator = nullptr;
|
||||
};
|
||||
|
||||
Context::Context() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
|
||||
|
||||
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_CXX_API_CONTEXT_H_
|
||||
#define MINDSPORE_LITE_SRC_CXX_API_CONTEXT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#ifndef SUPPORT_NNIE
|
||||
#include <any>
|
||||
#else
|
||||
#include <experimental/any>
|
||||
#endif
|
||||
#include "include/api/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
struct Context::Data {
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||
int32_t thread_num = 2;
|
||||
bool enable_parallel_ = false;
|
||||
std::vector<int32_t> affinity_core_list_;
|
||||
int affinity_mode_ = 0;
|
||||
std::shared_ptr<Delegate> delegate = nullptr;
|
||||
};
|
||||
|
||||
struct DeviceInfoContext::Data {
|
||||
#ifndef SUPPORT_NNIE
|
||||
std::map<std::string, std::any> params;
|
||||
#else
|
||||
std::map<std::string, std::experimental::any> params;
|
||||
#endif
|
||||
std::shared_ptr<Allocator> allocator = nullptr;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_CXX_API_CONTEXT_H_
|
|
@ -130,4 +130,71 @@ Status A2L_ConvertContext(Context *a_context, lite::InnerContext *l_context) {
|
|||
l_context->delegate = a_context->GetDelegate();
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context) {
|
||||
if ((a_context == nullptr) || (l_context == nullptr)) {
|
||||
MS_LOG(ERROR) << "Invalid context pointers.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
auto device_list = a_context->device_info_list;
|
||||
if (device_list.size() == 0) {
|
||||
MS_LOG(ERROR) << "Invalid device list.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
if (device_list.size() > kMaxNumOfDevices) {
|
||||
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
l_context->thread_num_ = a_context->thread_num;
|
||||
l_context->enable_parallel_ = a_context->enable_parallel_;
|
||||
l_context->affinity_core_list_ = a_context->affinity_core_list_;
|
||||
l_context->device_list_.clear();
|
||||
if (device_list[0]->GetDeviceType() != kCPU) {
|
||||
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
|
||||
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
|
||||
l_context->allocator = cpu_context->GetAllocator();
|
||||
if (l_context->allocator == nullptr) {
|
||||
l_context->allocator = Allocator::Create();
|
||||
if (l_context->allocator == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Allocator failed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Set new allocator.";
|
||||
cpu_context->SetAllocator(l_context->allocator);
|
||||
}
|
||||
|
||||
if (!IsAffinityModeValid(a_context->affinity_mode_)) {
|
||||
MS_LOG(ERROR)
|
||||
<< "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode_);
|
||||
|
||||
lite::DeviceInfo cpu_info = {0};
|
||||
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
|
||||
l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
|
||||
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
|
||||
if (device_list.size() == kMaxNumOfDevices) {
|
||||
lite::DeviceInfo device_info = {0};
|
||||
if (device_list[1]->GetDeviceType() == kGPU) {
|
||||
auto gpu_context = device_list[1]->Cast<GPUDeviceInfo>();
|
||||
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
|
||||
l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
|
||||
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
|
||||
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
|
||||
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
|
||||
device_info.npu_device_info_ = {npu_context->GetFrequency()};
|
||||
l_context->device_list_.push_back({lite::DT_NPU, device_info});
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid device.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
}
|
||||
l_context->delegate = a_context->delegate;
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "include/api/types.h"
|
||||
#include "include/lite_types.h"
|
||||
#include "src/inner_context.h"
|
||||
#include "src/cxx_api/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
|
@ -61,7 +62,7 @@ inline bool IsAffinityModeValid(int affinity_mode) {
|
|||
}
|
||||
|
||||
Status A2L_ConvertContext(Context *a_context, lite::InnerContext *l_context);
|
||||
|
||||
Status A2L_ConvertContext(const Context::Data *a_context, lite::Context *l_context);
|
||||
Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg);
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -31,6 +31,8 @@ file(GLOB_RECURSE TEST_UT_SRC
|
|||
${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
|
||||
${TEST_DIR}/ut/src/api/context_c_test.cc
|
||||
${TEST_DIR}/ut/src/api/tensor_c_test.cc
|
||||
)
|
||||
|
||||
if(MSLITE_ENABLE_RUNTIME_CONVERT)
|
||||
|
|
|
@ -90,3 +90,7 @@ echo 'Optimize Allocator'
|
|||
|
||||
echo 'Runtime config file test'
|
||||
./lite-test --gtest_filter="MixDataTypeTest.Config1"
|
||||
|
||||
echo 'run c api ut test'
|
||||
./lite-test --gtest_filter="TensorCTest.*"
|
||||
./lite-test --gtest_filter="ContextCTest.*"
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/c_api/context_c.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ContextCTest : public mindspore::CommonTest {
|
||||
public:
|
||||
ContextCTest() {}
|
||||
};
|
||||
|
||||
TEST_F(ContextCTest, common_test) {
|
||||
MSDeviceInfoHandle npu_device_info = MSDeviceInfoCreate(kMSDeviceTypeKirinNPU);
|
||||
ASSERT_TRUE(npu_device_info != nullptr);
|
||||
ASSERT_EQ(MSDeviceInfoGetDeviceType(npu_device_info), kMSDeviceTypeKirinNPU);
|
||||
|
||||
MSDeviceInfoSetProvider(npu_device_info, "vendor name");
|
||||
ASSERT_STREQ(MSDeviceInfoGetProvider(npu_device_info), "vendor name");
|
||||
|
||||
MSDeviceInfoSetProviderDevice(npu_device_info, "npu_a");
|
||||
ASSERT_STREQ(MSDeviceInfoGetProviderDevice(npu_device_info), "npu_a");
|
||||
|
||||
MSDeviceInfoSetFrequency(npu_device_info, 3);
|
||||
ASSERT_EQ(MSDeviceInfoGetFrequency(npu_device_info), 3);
|
||||
|
||||
MSContextHandle context = MSContextCreate();
|
||||
ASSERT_TRUE(context != nullptr);
|
||||
|
||||
MSContextSetThreadNum(context, 4);
|
||||
ASSERT_EQ(MSContextGetThreadNum(context), 4);
|
||||
|
||||
MSContextSetThreadAffinityMode(context, 2);
|
||||
ASSERT_EQ(MSContextGetThreadAffinityMode(context), 2);
|
||||
|
||||
constexpr size_t core_num = 4;
|
||||
int32_t core_list[core_num] = {1, 3, 2, 0};
|
||||
MSContextSetThreadAffinityCoreList(context, core_list, core_num);
|
||||
size_t ret_core_num;
|
||||
const int32_t *ret_core_list = nullptr;
|
||||
ret_core_list = MSContextGetThreadAffinityCoreList(context, &ret_core_num);
|
||||
ASSERT_EQ(ret_core_num, core_num);
|
||||
for (size_t i = 0; i < ret_core_num; i++) {
|
||||
ASSERT_EQ(ret_core_list[i], core_list[i]);
|
||||
}
|
||||
|
||||
MSContextSetEnableParallel(context, true);
|
||||
ASSERT_EQ(MSContextGetEnableParallel(context), true);
|
||||
|
||||
MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
|
||||
MSDeviceInfoDestroy(cpu_device_info);
|
||||
cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
|
||||
|
||||
MSDeviceInfoSetEnableFP16(cpu_device_info, true);
|
||||
ASSERT_EQ(MSDeviceInfoGetEnableFP16(cpu_device_info), true);
|
||||
|
||||
MSContextAddDeviceInfo(context, cpu_device_info);
|
||||
MSContextAddDeviceInfo(context, npu_device_info);
|
||||
MSContextDestroy(context);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/c_api/tensor_c.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TensorCTest : public mindspore::CommonTest {
|
||||
public:
|
||||
TensorCTest() {}
|
||||
};
|
||||
|
||||
TEST_F(TensorCTest, common_test) {
|
||||
constexpr size_t shape_num = 2;
|
||||
int64_t shape[shape_num] = {2, 3};
|
||||
MSTensorHandle tensor = MSTensorCreate("name001", kMSDataTypeNumberTypeInt32, shape, shape_num, nullptr, 0);
|
||||
ASSERT_TRUE(tensor != nullptr);
|
||||
ASSERT_STREQ(MSTensorGetName(tensor), "name001");
|
||||
ASSERT_EQ(MSTensorGetDataType(tensor), kMSDataTypeNumberTypeInt32);
|
||||
size_t ret_shape_num;
|
||||
const int64_t *ret_shape = MSTensorGetShape(tensor, &ret_shape_num);
|
||||
ASSERT_EQ(ret_shape_num, shape_num);
|
||||
for (size_t i = 0; i < ret_shape_num; i++) {
|
||||
ASSERT_EQ(ret_shape[i], shape[i]);
|
||||
}
|
||||
ASSERT_EQ(MSTensorGetElementNum(tensor), 6);
|
||||
ASSERT_EQ(MSTensorGetDataSize(tensor), 6 * sizeof(int32_t));
|
||||
ASSERT_EQ(MSTensorGetData(tensor), nullptr);
|
||||
ASSERT_TRUE(MSTensorGetMutableData(tensor) != nullptr);
|
||||
|
||||
MSTensorSetName(tensor, "name002");
|
||||
ASSERT_STREQ(MSTensorGetName(tensor), "name002");
|
||||
|
||||
MSTensorSetDataType(tensor, kMSDataTypeNumberTypeFloat32);
|
||||
ASSERT_EQ(MSTensorGetDataType(tensor), kMSDataTypeNumberTypeFloat32);
|
||||
constexpr size_t new_shape_num = 4;
|
||||
int64_t new_shape[new_shape_num] = {1, 2, 3, 1};
|
||||
MSTensorSetShape(tensor, new_shape, new_shape_num);
|
||||
size_t new_ret_shape_num;
|
||||
const int64_t *new_ret_shape = MSTensorGetShape(tensor, &new_ret_shape_num);
|
||||
ASSERT_EQ(new_ret_shape_num, new_shape_num);
|
||||
for (size_t i = 0; i < new_ret_shape_num; i++) {
|
||||
ASSERT_EQ(new_ret_shape[i], new_shape[i]);
|
||||
}
|
||||
|
||||
MSTensorSetFormat(tensor, kMSFormatNCHW);
|
||||
ASSERT_EQ(MSTensorGetFormat(tensor), kMSFormatNCHW);
|
||||
|
||||
constexpr size_t data_len = 6;
|
||||
ASSERT_EQ(MSTensorGetElementNum(tensor), data_len);
|
||||
ASSERT_EQ(MSTensorGetDataSize(tensor), data_len * sizeof(float));
|
||||
|
||||
float data[data_len] = {1, 2, 3, 4, 5, 6};
|
||||
MSTensorSetData(tensor, data);
|
||||
const float *ret_data = static_cast<const float *>(MSTensorGetData(tensor));
|
||||
for (size_t i = 0; i < data_len; i++) {
|
||||
ASSERT_EQ(ret_data[i], data[i]);
|
||||
}
|
||||
|
||||
MSTensorHandle clone = MSTensorClone(tensor);
|
||||
ASSERT_TRUE(clone != nullptr);
|
||||
ASSERT_STREQ(MSTensorGetName(clone), "");
|
||||
ASSERT_EQ(MSTensorGetDataType(clone), kMSDataTypeNumberTypeFloat32);
|
||||
size_t clone_shape_num;
|
||||
const int64_t *clone_shape = MSTensorGetShape(clone, &clone_shape_num);
|
||||
ASSERT_EQ(clone_shape_num, new_ret_shape_num);
|
||||
for (size_t i = 0; i < clone_shape_num; i++) {
|
||||
ASSERT_EQ(clone_shape[i], new_ret_shape[i]);
|
||||
}
|
||||
ASSERT_EQ(MSTensorGetElementNum(clone), MSTensorGetElementNum(tensor));
|
||||
ASSERT_EQ(MSTensorGetDataSize(clone), MSTensorGetDataSize(tensor));
|
||||
ASSERT_TRUE(MSTensorGetData(clone) != MSTensorGetData(tensor));
|
||||
|
||||
MSTensorDestroy(tensor);
|
||||
MSTensorDestroy(clone);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -88,9 +88,9 @@ FMT_FILE_LIST='__format_files_list__'
|
|||
if [[ "X${mode}" == "Xall" ]]; then
|
||||
find mindspore/{ccsrc,core,lite} -type f -name "*" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true
|
||||
elif [[ "X${mode}" == "Xchanged" ]]; then
|
||||
git diff --name-only | grep "mindspore/ccsrc\|mindspore/core\|mindspore/lite" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true
|
||||
git diff --name-only | grep "mindspore/ccsrc\|mindspore/core\|mindspore/lite\|include" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true
|
||||
else # "X${mode}" == "Xlastcommit"
|
||||
git diff --name-only HEAD~ HEAD | grep "mindspore/ccsrc\|mindspore/core\|mindspore/lite" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true
|
||||
git diff --name-only HEAD~ HEAD | grep "mindspore/ccsrc\|mindspore/core\|mindspore/lite\|include" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true
|
||||
fi
|
||||
|
||||
while read line; do
|
||||
|
|
Loading…
Reference in New Issue