forked from mindspore-Ecosystem/mindspore
c api support custom op
This commit is contained in:
parent
46ded90de6
commit
5c16bef816
|
@ -24,6 +24,10 @@ else()
|
|||
target_link_libraries(mindspore_c_api PRIVATE ${PYTHON_LIBRARIES})
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
add_compile_definitions(BUILDING_C_API_DLL)
|
||||
endif()
|
||||
|
||||
if(ENABLE_D OR ENABLE_GPU)
|
||||
target_link_libraries(mindspore_c_api PRIVATE ${SECUREC_LIBRARY} mindspore_backend mindspore_core
|
||||
mindspore_common proto_input mindspore::protobuf)
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_C_API_BASE_MACROS_H_
|
||||
|
||||
#if (defined(_WIN32) || defined(__WIN32__) || defined(WIN32) || defined(__CYGWIN__))
|
||||
#ifdef BUILDING_CORE_DLL
|
||||
#ifdef BUILDING_C_API_DLL
|
||||
#define MIND_C_API __declspec(dllexport)
|
||||
#else
|
||||
#define MIND_C_API __declspec(dllimport)
|
||||
|
@ -27,4 +27,6 @@
|
|||
#define MIND_C_API __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
#define MAX_DIMS 8
|
||||
|
||||
#endif // MINDSPORE_CCSRC_C_API_BASE_MACROS_H_
|
||||
|
|
|
@ -17,94 +17,228 @@
|
|||
#ifndef MINDSPORE_CCSRC_C_API_BASE_TYPES_H_
|
||||
#define MINDSPORE_CCSRC_C_API_BASE_TYPES_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// \brief TypeId defines data type identifiers.
|
||||
typedef enum TypeId {
|
||||
kTypeUnknown = 0,
|
||||
//
|
||||
// Meta types.
|
||||
//
|
||||
kMetaTypeBegin = kTypeUnknown,
|
||||
kMetaTypeType, // Type
|
||||
kMetaTypeAnything,
|
||||
kMetaTypeObject,
|
||||
kMetaTypeTypeType, // TypeType
|
||||
kMetaTypeProblem,
|
||||
kMetaTypeExternal,
|
||||
kMetaTypeNone,
|
||||
kMetaTypeNull,
|
||||
kMetaTypeEllipsis,
|
||||
kMetaTypeEnd,
|
||||
//
|
||||
// Object types
|
||||
//
|
||||
kObjectTypeBegin = kMetaTypeEnd,
|
||||
kObjectTypeNumber,
|
||||
kObjectTypeString,
|
||||
kObjectTypeList,
|
||||
kObjectTypeTuple,
|
||||
kObjectTypeSlice,
|
||||
kObjectTypeKeyword,
|
||||
kObjectTypeTensorType,
|
||||
kObjectTypeRowTensorType,
|
||||
kObjectTypeCOOTensorType,
|
||||
kObjectTypeUndeterminedType,
|
||||
kObjectTypeClass,
|
||||
kObjectTypeDictionary,
|
||||
kObjectTypeFunction,
|
||||
kObjectTypeJTagged,
|
||||
kObjectTypeSymbolicKeyType,
|
||||
kObjectTypeEnvType,
|
||||
kObjectTypeRefKey,
|
||||
kObjectTypeRef,
|
||||
kObjectTypeEnd,
|
||||
//
|
||||
// Number Types
|
||||
//
|
||||
kNumberTypeBegin = kObjectTypeEnd,
|
||||
kNumberTypeBool,
|
||||
kNumberTypeInt,
|
||||
kNumberTypeInt8,
|
||||
kNumberTypeInt16,
|
||||
kNumberTypeInt32,
|
||||
kNumberTypeInt64,
|
||||
kNumberTypeUInt,
|
||||
kNumberTypeUInt8,
|
||||
kNumberTypeUInt16,
|
||||
kNumberTypeUInt32,
|
||||
kNumberTypeUInt64,
|
||||
kNumberTypeFloat,
|
||||
kNumberTypeFloat16,
|
||||
kNumberTypeFloat32,
|
||||
kNumberTypeFloat64,
|
||||
kNumberTypeComplex,
|
||||
kNumberTypeComplex64,
|
||||
kNumberTypeComplex128,
|
||||
kNumberTypeInt4,
|
||||
kNumberTypeGLUInt,
|
||||
kNumberTypeEnd,
|
||||
//
|
||||
// Monad Types
|
||||
//
|
||||
kMonadTypeBegin = kNumberTypeEnd,
|
||||
kObjectTypeMonad,
|
||||
kObjectTypeUMonad,
|
||||
kObjectTypeIOMonad,
|
||||
kMonadTypeEnd,
|
||||
//
|
||||
// Sparse Types
|
||||
//
|
||||
kSparseTypeBegin = kMonadTypeEnd,
|
||||
kObjectTypeCSRTensorType,
|
||||
kObjectTypeSparseTensorType,
|
||||
kObjectTypeMapTensorType,
|
||||
kSparseTypeEnd,
|
||||
// New types should placed at the end of enum,
|
||||
// in order to keep fit with the type of existing model on the lite side.
|
||||
} TypeId;
|
||||
typedef enum DataTypeC {
|
||||
MS_NONE = 0,
|
||||
MS_BOOL = 30,
|
||||
MS_INT8 = 32,
|
||||
MS_INT16 = 33,
|
||||
MS_INT32 = 34,
|
||||
MS_INT64 = 35,
|
||||
MS_UINT8 = 37,
|
||||
MS_UINT16 = 38,
|
||||
MS_UINT32 = 39,
|
||||
MS_UINT64 = 40,
|
||||
MS_FLOAT16 = 42,
|
||||
MS_FLOAT32 = 43,
|
||||
MS_FLOAT64 = 44,
|
||||
MS_COMPLEX64 = 46,
|
||||
MS_COMPLEX128 = 47,
|
||||
MS_INVALID_TYPE = INT32_MAX,
|
||||
} DataTypeC;
|
||||
|
||||
typedef enum DTypeFormat {
|
||||
None_None, // {"", ""};
|
||||
None_Default, // {"", "DefaultFormat"};
|
||||
|
||||
BOOL_None, // {"bool", ""};
|
||||
BOOL_Default, // {"bool", "DefaultFormat"};
|
||||
BOOL_5HD, // {"bool", "NC1HWC0"};
|
||||
BOOL_FracZ, // {"bool", "FRACTAL_Z"};
|
||||
BOOL_FracNZ, // {"bool", "FRACTAL_NZ"};
|
||||
BOOL_C1HWNCoC0, // {"bool", "C1HWNCoC0"};
|
||||
BOOL_NCHW, // {"bool", "NCHW"};
|
||||
BOOL_NHWC, // {"bool", "NHWC"};
|
||||
BOOL_HWCN, // {"bool", "HWCN"};
|
||||
BOOL_NDHWC, // {"bool", "NDHWC"};
|
||||
BOOL_ChannelLast, // {"bool", "ChannelLast"};
|
||||
BOOL_Default_Tuple, // {"bool", "DefaultFormat", "tuple"};
|
||||
BOOL_Default_List, // {"bool", "DefaultFormat", "list"};
|
||||
|
||||
I8_None, // {"int8", ""};
|
||||
I8_Default, // {"int8", "DefaultFormat"};
|
||||
I8_5HD, // {"int8", "NC1HWC0"};
|
||||
I8_FracZ, // {"int8", "FRACTAL_Z"};
|
||||
I8_FracNZ, // {"int8", "FRACTAL_NZ"};
|
||||
I8_C1HWNCoC0, // {"int8", "C1HWNCoC0"};
|
||||
I8_NCHW, // {"int8", "NCHW"};
|
||||
I8_NHWC, // {"int8", "NHWC"};
|
||||
I8_HWCN, // {"int8", "HWCN"};
|
||||
I8_NDHWC, // {"int8", "NDHWC"};
|
||||
I8_NCDHW, // {"int8", "NCDHW"};
|
||||
I8_ChannelLast, // {"int8", "ChannelLast"};
|
||||
I8_NDC1HWC0, // {"int8", "NDC1HWC0"};
|
||||
I8_NC1HWC0, // {"int8", "NC1HWC0"};
|
||||
I8_Default_Tuple, // {"int8", "DefaultFormat", "tuple"};
|
||||
I8_Default_List, // {"int8", "DefaultFormat", "list"};
|
||||
|
||||
U8_None, // {"uint8", ""};
|
||||
U8_Default, // {"uint8", "DefaultFormat"};
|
||||
U8_5HD, // {"uint8", "NC1HWC0"};
|
||||
U8_FracZ, // {"uint8", "FRACTAL_Z"};
|
||||
U8_FracNZ, // {"uint8", "FRACTAL_NZ"};
|
||||
U8_C1HWNCoC0, // {"uint8", "C1HWNCoC0"};
|
||||
U8_NCHW, // {"uint8", "NCHW"};
|
||||
U8_NHWC, // {"uint8", "NHWC"};
|
||||
U8_HWCN, // {"uint8", "HWCN"};
|
||||
U8_NDHWC, // {"uint8", "NDHWC"};
|
||||
U8_NCDHW, // {"uint8", "NCDHW"};
|
||||
U8_ChannelLast, // {"uint8", "ChannelLast"};
|
||||
U8_NDC1HWC0, // {"uint8", "NDC1HWC0"};
|
||||
U8_NC1HWC0, // {"uint8", "NC1HWC0"};
|
||||
U8_Default_Tuple, // {"uint8", "DefaultFormat", "tuple"};
|
||||
U8_Default_List, // {"uint8", "DefaultFormat", "list"};
|
||||
|
||||
I16_None, // {"int16", ""};
|
||||
I16_Default, // {"int16", "DefaultFormat"};
|
||||
I16_5HD, // {"int16", "NC1HWC0"};
|
||||
I16_FracZ, // {"int16", "FRACTAL_Z"};
|
||||
I16_FracNZ, // {"int16", "FRACTAL_NZ"};
|
||||
I16_C1HWNCoC0, // {"int16", "C1HWNCoC0"};
|
||||
I16_NCHW, // {"int16", "NCHW"};
|
||||
I16_NHWC, // {"int16", "NHWC"};
|
||||
I16_HWCN, // {"int16", "HWCN"};
|
||||
I16_NDHWC, // {"int16", "NDHWC"};
|
||||
I16_ChannelLast, // {"int16", "ChannelLast"};
|
||||
I16_Default_Tuple, // {"int16", "DefaultFormat", "tuple"};
|
||||
I16_Default_List, // {"int16", "DefaultFormat", "list"};
|
||||
|
||||
U16_None, // {"uint16", ""};
|
||||
U16_Default, // {"uint16", "DefaultFormat"};
|
||||
U16_5HD, // {"uint16", "NC1HWC0"};
|
||||
U16_FracZ, // {"uint16", "FRACTAL_Z"};
|
||||
U16_FracNZ, // {"uint16", "FRACTAL_NZ"};
|
||||
U16_C1HWNCoC0, // {"uint16", "C1HWNCoC0"};
|
||||
U16_NCHW, // {"uint16", "NCHW"};
|
||||
U16_NHWC, // {"uint16", "NHWC"};
|
||||
U16_HWCN, // {"uint16", "HWCN"};
|
||||
U16_NDHWC, // {"uint16", "NDHWC"};
|
||||
U16_ChannelLast, // {"uint16", "ChannelLast"};
|
||||
U16_Default_Tuple, // {"uint16", "DefaultFormat", "tuple"};
|
||||
U16_Default_List, // {"uint16", "DefaultFormat", "list"};
|
||||
|
||||
I32_None, // {"int32", ""};
|
||||
I32_Default, // {"int32", "DefaultFormat"};
|
||||
I32_5HD, // {"int32", "NC1HWC0"};
|
||||
I32_FracZ, // {"int32", "FRACTAL_Z"};
|
||||
I32_FracNZ, // {"int32", "FRACTAL_NZ"};
|
||||
I32_C1HWNCoC0, // {"int32", "C1HWNCoC0"};
|
||||
I32_NCHW, // {"int32", "NCHW"};
|
||||
I32_NHWC, // {"int32", "NHWC"};
|
||||
I32_HWCN, // {"int32", "HWCN"};
|
||||
I32_NDHWC, // {"int32", "NDHWC"};
|
||||
I32_NDC1HWC0, // {"int32", "NDC1HWC0"};
|
||||
I32_NCDHW, // {"int32", "NCDHW"};
|
||||
I32_ChannelLast, // {"int32", "ChannelLast"};
|
||||
I32_Default_Tuple, // {"int32", "DefaultFormat", "tuple"};
|
||||
I32_Default_List, // {"int32", "DefaultFormat", "list"};
|
||||
|
||||
U32_None, // {"uint32", ""};
|
||||
U32_Default, // {"uint32", "DefaultFormat"};
|
||||
U32_5HD, // {"uint32", "NC1HWC0"};
|
||||
U32_FracZ, // {"uint32", "FRACTAL_Z"};
|
||||
U32_FracNZ, // {"uint32", "FRACTAL_NZ"};
|
||||
U32_C1HWNCoC0, // {"uint32", "C1HWNCoC0"};
|
||||
U32_NCHW, // {"uint32", "NCHW"};
|
||||
U32_NHWC, // {"uint32", "NHWC"};
|
||||
U32_HWCN, // {"uint32", "HWCN"};
|
||||
U32_NDHWC, // {"uint32", "NDHWC"};
|
||||
U32_ChannelLast, // {"uint32", "ChannelLast"};
|
||||
U32_Default_Tuple, // {"uint32", "DefaultFormat", "tuple"};
|
||||
U32_Default_List, // {"uint32", "DefaultFormat", "list"};
|
||||
|
||||
I64_None, // {"int64", ""};
|
||||
I64_Default, // {"int64", "DefaultFormat"};
|
||||
I64_5HD, // {"int64", "NC1HWC0"};
|
||||
I64_FracZ, // {"int64", "FRACTAL_Z"};
|
||||
I64_FracNZ, // {"int64", "FRACTAL_NZ"};
|
||||
I64_C1HWNCoC0, // {"int64", "C1HWNCoC0"};
|
||||
I64_NCHW, // {"int64", "NCHW"};
|
||||
I64_NHWC, // {"int64", "NHWC"};
|
||||
I64_HWCN, // {"int64", "HWCN"};
|
||||
I64_NDHWC, // {"int64", "NDHWC"};
|
||||
I64_ChannelLast, // {"int64", "ChannelLast"};
|
||||
I64_Default_Tuple, // {"int64", "DefaultFormat", "tuple"};
|
||||
I64_Default_List, // {"int64", "DefaultFormat", "list"};
|
||||
|
||||
U64_None, // {"uint64", ""};
|
||||
U64_Default, // {"uint64", "DefaultFormat"};
|
||||
U64_5HD, // {"uint64", "NC1HWC0"};
|
||||
U64_FracZ, // {"uint64", "FRACTAL_Z"};
|
||||
U64_FracNZ, // {"uint64", "FRACTAL_NZ"};
|
||||
U64_C1HWNCoC0, // {"uint64", "C1HWNCoC0"};
|
||||
U64_NCHW, // {"uint64", "NCHW"};
|
||||
U64_NHWC, // {"uint64", "NHWC"};
|
||||
U64_HWCN, // {"uint64", "HWCN"};
|
||||
U64_NDHWC, // {"uint64", "NDHWC"};
|
||||
U64_ChannelLast, // {"uint64", "ChannelLast"};
|
||||
U64_Default_Tuple, // {"uint64", "DefaultFormat", "tuple"};
|
||||
U64_Default_List, // {"uint64", "DefaultFormat", "list"};
|
||||
|
||||
F16_None, // {"float16", ""};
|
||||
F16_Default, // {"float16", "DefaultFormat"};
|
||||
F16_5HD, // {"float16", "NC1HWC0"};
|
||||
F16_FracZ, // {"float16", "FRACTAL_Z"};
|
||||
F16_FracNZ, // {"float16", "FRACTAL_NZ"};
|
||||
F16_C1HWNCoC0, // {"float16", "C1HWNCoC0"};
|
||||
F16_NCHW, // {"float16", "NCHW"};
|
||||
F16_NHWC, // {"float16", "NHWC"};
|
||||
F16_HWCN, // {"float16", "HWCN"};
|
||||
F16_NDHWC, // {"float16", "NDHWC"};
|
||||
F16_NCDHW, // {"float16", "NCDHW"};
|
||||
F16_DHWCN, // {"float16", "DHWCN"};
|
||||
F16_NDC1HWC0, // {"float16", "NDC1HWC0"};
|
||||
F16_FRACTAL_Z_3D, // {"float16", "FRACTAL_Z_3D"};
|
||||
F16_FracZNLSTM, // {"float16", "FRACTAL_ZN_LSTM"};
|
||||
F16_FracZNRNN, // {"float16", "FRACTAL_ZN_RNN"};
|
||||
F16_ND_RNNBIAS, // {"float16", "ND_RNN_BIAS"};
|
||||
F16_ChannelLast, // {"float16", "ChannelLast"};
|
||||
F16_Default_Tuple, // {"float16", "DefaultFormat", "tuple"};
|
||||
F16_Default_List, // {"float16", "DefaultFormat", "list"};
|
||||
|
||||
F32_None, // {"float32", ""};
|
||||
F32_Default, // {"float32", "DefaultFormat"};
|
||||
F32_5HD, // {"float32", "NC1HWC0"};
|
||||
F32_FracZ, // {"float32", "FRACTAL_Z"};
|
||||
F32_FracNZ, // {"float32", "FRACTAL_NZ"};
|
||||
F32_C1HWNCoC0, // {"float32", "C1HWNCoC0"};
|
||||
F32_NCHW, // {"float32", "NCHW"};
|
||||
F32_NHWC, // {"float32", "NHWC"};
|
||||
F32_HWCN, // {"float32", "HWCN"};
|
||||
F32_NDHWC, // {"float32", "NDHWC"};
|
||||
F32_NCDHW, // {"float32", "NCDHW"};
|
||||
F32_DHWCN, // {"float32", "DHWCN"};
|
||||
F32_NDC1HWC0, // {"float32", "NDC1HWC0"};
|
||||
F32_FRACTAL_Z_3D, // {"float32", "FRACTAL_Z_3D"};
|
||||
F32_FracZNLSTM, // {"float32", "FRACTAL_ZN_LSTM"};
|
||||
F32_FracZNRNN, // {"float32", "FRACTAL_ZN_RNN"};
|
||||
F32_ND_RNNBIAS, // {"float32", "ND_RNN_BIAS"};
|
||||
F32_ChannelLast, // {"float32", "ChannelLast"};
|
||||
F32_Default_Tuple, // {"float32", "DefaultFormat", "tuple"};
|
||||
F32_Default_List, // {"float32", "DefaultFormat", "list"};
|
||||
|
||||
F64_None, // {"float64", ""};
|
||||
F64_Default, // {"float64", "DefaultFormat"};
|
||||
F64_5HD, // {"float64", "NC1HWC0"};
|
||||
F64_FracZ, // {"float64", "FRACTAL_Z"};
|
||||
F64_FracNZ, // {"float64", "FRACTAL_NZ"};
|
||||
F64_C1HWNCoC0, // {"float64", "C1HWNCoC0"};
|
||||
F64_NCHW, // {"float64", "NCHW"};
|
||||
F64_NHWC, // {"float64", "NHWC"};
|
||||
F64_HWCN, // {"float64", "HWCN"};
|
||||
F64_NDHWC, // {"float64", "NDHWC"};
|
||||
F64_ChannelLast, // {"float64", "ChannelLast"};
|
||||
F64_Default_Tuple, // {"float64", "DefaultFormat", "tuple"};
|
||||
F64_Default_List, // {"float64", "DefaultFormat", "list"};
|
||||
|
||||
C64_Default, // {"complex64", "DefaultFormat"};
|
||||
C128_Default, // {"complex128", "DefaultFormat"};
|
||||
} DTypeFormat;
|
||||
|
||||
typedef enum PadMode {
|
||||
PAD = 0,
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_IR_ABSTRACT_H_
|
||||
#define MINDSPORE_CCSRC_C_API_IR_ABSTRACT_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_INCLUDE_ABSTRACT_H_
|
||||
#define MINDSPORE_CCSRC_C_API_INCLUDE_ABSTRACT_H_
|
||||
|
||||
#include <stdlib.h>
|
||||
#include "c_api/base/macros.h"
|
||||
|
@ -46,7 +46,7 @@ MIND_C_API STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, Co
|
|||
/// \param[in] shape_size The size of the shape array, i.e., the dimension of node output.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSSetAbstract(ResMgrHandle res_mgr, NodeHandle node, TypeId type, const int64_t shape[],
|
||||
MIND_C_API STATUS MSSetAbstract(ResMgrHandle res_mgr, NodeHandle node, DataTypeC type, const int64_t shape[],
|
||||
size_t shape_size);
|
||||
|
||||
/// \brief Get multiple Abstract to the node. Usually used in the case that the node has multiple outputs.
|
||||
|
@ -58,10 +58,10 @@ MIND_C_API STATUS MSSetAbstract(ResMgrHandle res_mgr, NodeHandle node, TypeId ty
|
|||
/// \param[in] shape_sizes The array contains the size of all shape, i.e., the dimension of all node output.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSSetMultiAbstract(ResMgrHandle res_mgr, NodeHandle node, TypeId type, const int64_t **shapes,
|
||||
MIND_C_API STATUS MSSetMultiAbstract(ResMgrHandle res_mgr, NodeHandle node, DataTypeC type, const int64_t **shapes,
|
||||
const size_t shape_sizes[], size_t abs_num);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_C_API_IR_ABSTRACT_H_
|
||||
#endif // MINDSPORE_CCSRC_C_API_INCLUDE_ABSTRACT_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_IR_ATTRIBUTE_H_
|
||||
#define MINDSPORE_CCSRC_C_API_IR_ATTRIBUTE_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_INCLUDE_ATTRIBUTE_H_
|
||||
#define MINDSPORE_CCSRC_C_API_INCLUDE_ATTRIBUTE_H_
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
|
@ -76,7 +76,7 @@ MIND_C_API STATUS MSOpSetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, co
|
|||
/// \param[in] value The input value of the attribute.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, TypeId value);
|
||||
MIND_C_API STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value);
|
||||
|
||||
/// \brief Set the attribute of the target node with the given name and value.
|
||||
///
|
||||
|
@ -87,7 +87,7 @@ MIND_C_API STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const cha
|
|||
/// \param[in] vec_size number of elements in the array.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, TypeId value[],
|
||||
MIND_C_API STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value[],
|
||||
size_t vec_size);
|
||||
|
||||
/// \brief Set the attribute of the target node with the given name and value.
|
||||
|
@ -97,12 +97,12 @@ MIND_C_API STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, cons
|
|||
/// \param[in] attr_name The attribute name associates with the node.
|
||||
/// \param[in] value The input value array of the attribute.
|
||||
/// \param[in] vec_size number of elements in the array.
|
||||
/// \param[in] dataType Data type id. Currently support kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat32,
|
||||
/// \param[in] data_type Data type id. Currently support kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat32,
|
||||
/// kNumberTypeBool.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSOpSetAttrArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, void *value,
|
||||
size_t vec_size, TypeId dataType);
|
||||
size_t vec_size, DataTypeC data_type);
|
||||
|
||||
/// \brief Set the attribute of the target node with the given name and value as ValueList.
|
||||
///
|
||||
|
@ -178,11 +178,11 @@ MIND_C_API AttrHandle MSNewAttrBool(ResMgrHandle res_mgr, const bool v);
|
|||
/// \param[in] res_mgr Resource Handle that manages the nodes of the funcGraph.
|
||||
/// \param[in] value Given array.
|
||||
/// \param[in] vec_size Given array size.
|
||||
/// \param[in] dataType Datatype of the array.
|
||||
/// \param[in] data_type Datatype of the array.
|
||||
///
|
||||
/// \return Attribute value handle
|
||||
MIND_C_API AttrHandle MSOpNewAttrs(ResMgrHandle res_mgr, void *value, size_t vec_size, TypeId data_type);
|
||||
MIND_C_API AttrHandle MSNewAttrArray(ResMgrHandle res_mgr, void *value, size_t vec_size, DataTypeC data_type);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_C_API_IR_ATTRIBUTE_H_
|
||||
#endif // MINDSPORE_CCSRC_C_API_INCLUDE_ATTRIBUTE_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_IR_CONTEXT_H_
|
||||
#define MINDSPORE_CCSRC_C_API_IR_CONTEXT_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_INCLUDE_CONTEXT_H_
|
||||
#define MINDSPORE_CCSRC_C_API_INCLUDE_CONTEXT_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
@ -88,4 +88,4 @@ MIND_C_API bool MSGetInfer(ResMgrHandle res_mgr);
|
|||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_C_API_IR_CONTEXT_H_
|
||||
#endif // MINDSPORE_CCSRC_C_API_INCLUDE_CONTEXT_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_IR_GRAPH_H_
|
||||
#define MINDSPORE_CCSRC_C_API_IR_GRAPH_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_INCLUDE_GRAPH_H_
|
||||
#define MINDSPORE_CCSRC_C_API_INCLUDE_GRAPH_H_
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
|
@ -150,4 +150,4 @@ MIND_C_API STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, Tensor
|
|||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_C_API_IR_GRAPH_H_
|
||||
#endif // MINDSPORE_CCSRC_C_API_INCLUDE_GRAPH_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_IR_NODE_H_
|
||||
#define MINDSPORE_CCSRC_C_API_IR_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_INCLUDE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_C_API_INCLUDE_NODE_H_
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
|
@ -30,6 +30,32 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// \brief The struct to describe custom op's basic info. For output_dtypes and dtype_infer_func, only one of them need
|
||||
/// to be specified. For output_shapes and shape_infer_func, only one of them need to be specified as well, and
|
||||
/// output_dims must be given if output_shapes is specified.
|
||||
typedef struct CustomOpInfo {
|
||||
char *func_name;
|
||||
char *func_type;
|
||||
char *target;
|
||||
char **input_name;
|
||||
size_t input_num;
|
||||
char **output_name;
|
||||
size_t output_num;
|
||||
char **attr_name;
|
||||
AttrHandle *attr_value;
|
||||
size_t attr_num;
|
||||
DTypeFormat **dtype_formats;
|
||||
size_t dtype_formats_num;
|
||||
int64_t **output_shapes;
|
||||
size_t *output_dims;
|
||||
DataTypeC *output_dtypes;
|
||||
STATUS(*dtype_infer_func)
|
||||
(const DataTypeC *input_types, size_t input_num, DataTypeC *output_types, size_t output_num);
|
||||
STATUS(*shape_infer_func)
|
||||
(int64_t **input_shapes, const size_t *input_dims, size_t input_num, int64_t **output_shapes, size_t *output_dims,
|
||||
size_t output_num);
|
||||
} CustomOpInfo;
|
||||
|
||||
/// \brief Create a new Operator node.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
|
@ -89,6 +115,18 @@ MIND_C_API NodeHandle MSNewSwitch(ResMgrHandle res_mgr, GraphHandle graph, Handl
|
|||
MIND_C_API NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, GraphHandle body_graph,
|
||||
GraphHandle after_graph);
|
||||
|
||||
/// \brief Create a custom operator.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
/// \param[in] graph The given function graph pointer handle.
|
||||
/// \param[in] inputs An array of operator's input nodes.
|
||||
/// \param[in] input_num The number of nodes in the array.
|
||||
/// \param[in] info An CustomOpInfo struct which describes the information of custom operator.
|
||||
///
|
||||
/// \return The created custom operator node.
|
||||
MIND_C_API NodeHandle MSNewCustomOp(ResMgrHandle res_mgr, GraphHandle graph, Handle const inputs[], size_t input_num,
|
||||
CustomOpInfo info);
|
||||
|
||||
/// \brief Get specified input node of Operator.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
|
@ -138,7 +176,7 @@ MIND_C_API NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph,
|
|||
/// \param[in] shape_size The size of shape, i.e., the dimension of the Placeholder.
|
||||
///
|
||||
/// \return The created Placeholder node handle.
|
||||
MIND_C_API NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph, TypeId type, const int64_t shape[],
|
||||
MIND_C_API NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph, DataTypeC type, const int64_t shape[],
|
||||
size_t shape_size);
|
||||
|
||||
/// \brief Create a Variable node of tensor, which contains variable tensor data.
|
||||
|
@ -152,7 +190,7 @@ MIND_C_API NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph,
|
|||
/// \param[in] data_len The length of data.
|
||||
///
|
||||
/// \return The created Variable node handle.
|
||||
MIND_C_API NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle graph, void *data, TypeId type,
|
||||
MIND_C_API NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle graph, void *data, DataTypeC type,
|
||||
const int64_t shape[], size_t shape_size, size_t data_len);
|
||||
|
||||
/// \brief Create a Variable node from a Tensor instance with data.
|
||||
|
@ -191,7 +229,7 @@ MIND_C_API void *MSTensorVariableGetData(ResMgrHandle res_mgr, ConstNodeHandle n
|
|||
/// \param[in] data_len The length of data.
|
||||
///
|
||||
/// \return The created Constant node handle.
|
||||
MIND_C_API NodeHandle MSNewTensorConstant(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[],
|
||||
MIND_C_API NodeHandle MSNewTensorConstant(ResMgrHandle res_mgr, void *data, DataTypeC type, const int64_t shape[],
|
||||
size_t shape_size, size_t data_len);
|
||||
|
||||
/// \brief Create a Constant node from a Tensor instance with data.
|
||||
|
@ -274,7 +312,7 @@ MIND_C_API NodeHandle MSNewStringConstant(ResMgrHandle res_mgr, const char *str)
|
|||
/// \param[in] str The type.
|
||||
///
|
||||
/// \return The created Constant node handle.
|
||||
MIND_C_API NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, TypeId type);
|
||||
MIND_C_API NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, DataTypeC type);
|
||||
|
||||
/// \brief Get value from the int32 scalar Constant node.
|
||||
///
|
||||
|
@ -348,7 +386,7 @@ MIND_C_API STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHa
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The obtained type value.
|
||||
MIND_C_API TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
MIND_C_API DataTypeC MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Set Operator node name.
|
||||
///
|
||||
|
@ -372,4 +410,4 @@ MIND_C_API STATUS MSNodeGetName(ResMgrHandle res_mgr, ConstNodeHandle node, char
|
|||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_C_API_IR_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_C_API_INCLUDE_NODE_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_IR_FUNC_TENSOR_H_
|
||||
#define MINDSPORE_CCSRC_C_API_IR_FUNC_TENSOR_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_INCLUDE_FUNC_TENSOR_H_
|
||||
#define MINDSPORE_CCSRC_C_API_INCLUDE_FUNC_TENSOR_H_
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
|
@ -39,7 +39,7 @@ extern "C" {
|
|||
/// \param[in] data_len The length of data in bytes.
|
||||
///
|
||||
/// \return The pointer of the created tensor instance.
|
||||
MIND_C_API TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[],
|
||||
MIND_C_API TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, DataTypeC type, const int64_t shape[],
|
||||
size_t shape_size, size_t data_len);
|
||||
|
||||
/// \brief Create a tensor with path to a space-sperated txt file.
|
||||
|
@ -51,8 +51,8 @@ MIND_C_API TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId typ
|
|||
/// \param[in] path path to the file.
|
||||
///
|
||||
/// \return The pointer of the created tensor instance.
|
||||
MIND_C_API TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_t shape[], size_t shape_size,
|
||||
const char *path);
|
||||
MIND_C_API TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, DataTypeC type, const int64_t shape[],
|
||||
size_t shape_size, const char *path);
|
||||
|
||||
/// \brief Create a tensor with input data buffer and given source data type.
|
||||
///
|
||||
|
@ -65,7 +65,7 @@ MIND_C_API TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, c
|
|||
///
|
||||
/// \return The pointer of the created tensor instance.
|
||||
MIND_C_API TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data, const int64_t shape[],
|
||||
size_t shape_size, TypeId tensor_type, TypeId src_type);
|
||||
size_t shape_size, DataTypeC tensor_type, DataTypeC src_type);
|
||||
|
||||
/// \brief Get the raw pointer of tensor data.
|
||||
///
|
||||
|
@ -82,7 +82,7 @@ MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, ConstTensorHandle tensor)
|
|||
/// \param[in] type The data type to be set.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, TypeId type);
|
||||
MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, DataTypeC type);
|
||||
|
||||
/// \brief Get tensor data type.
|
||||
///
|
||||
|
@ -90,7 +90,7 @@ MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor,
|
|||
/// \param[in] tensor The pointer of the tensor instance.
|
||||
///
|
||||
/// \return The data type of tensor.
|
||||
MIND_C_API TypeId MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
|
||||
MIND_C_API DataTypeC MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
|
||||
|
||||
/// \brief Get the byte size of tensor data.
|
||||
///
|
||||
|
@ -142,4 +142,4 @@ MIND_C_API STATUS MSTensorGetShape(ResMgrHandle res_mgr, ConstTensorHandle tenso
|
|||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_C_API_IR_FUNC_TENSOR_H_
|
||||
#endif // MINDSPORE_CCSRC_C_API_INCLUDE_FUNC_TENSOR_H_
|
||||
|
|
|
@ -40,7 +40,7 @@ STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, ConstNodeHand
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSSetAbstract(ResMgrHandle res_mgr, NodeHandle node, TypeId type, const int64_t shape[], size_t shape_size) {
|
||||
STATUS MSSetAbstract(ResMgrHandle res_mgr, NodeHandle node, DataTypeC type, const int64_t shape[], size_t shape_size) {
|
||||
if (res_mgr == nullptr || node == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [shape] are nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -56,7 +56,7 @@ STATUS MSSetAbstract(ResMgrHandle res_mgr, NodeHandle node, TypeId type, const i
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSSetMultiAbstract(ResMgrHandle res_mgr, NodeHandle node, TypeId type, const int64_t **shapes,
|
||||
STATUS MSSetMultiAbstract(ResMgrHandle res_mgr, NodeHandle node, DataTypeC type, const int64_t **shapes,
|
||||
const size_t shape_sizes[], size_t abs_num) {
|
||||
if (res_mgr == nullptr || node == nullptr || shapes == nullptr || shape_sizes == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [shapes] or [shape_sizes] are nullptr.";
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "c_api/src/helper.h"
|
||||
#include "c_api/src/common.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "c_api/src/utils.h"
|
||||
|
||||
PrimitivePtr GetOpPrim(ResMgrHandle res_mgr, ConstNodeHandle node) {
|
||||
auto src_node = GetSrcPtr<CNodePtr>(res_mgr, node);
|
||||
|
@ -101,7 +102,7 @@ STATUS MSOpSetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *a
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, TypeId value) {
|
||||
STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value) {
|
||||
if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -116,7 +117,7 @@ STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const char *attr_nam
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, TypeId value[],
|
||||
STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value[],
|
||||
size_t vec_size) {
|
||||
if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
|
||||
|
@ -138,7 +139,7 @@ STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, const char *att
|
|||
}
|
||||
|
||||
STATUS MSOpSetAttrArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, void *value, size_t vec_size,
|
||||
TypeId dataType) {
|
||||
DataTypeC data_type) {
|
||||
if (res_mgr == nullptr || op == nullptr || attr_name == nullptr || value == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] or [value_vec] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -149,29 +150,29 @@ STATUS MSOpSetAttrArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_na
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
switch (dataType) {
|
||||
case TypeId::kNumberTypeBool: {
|
||||
switch (data_type) {
|
||||
case MS_BOOL: {
|
||||
std::vector<bool> vec_value(reinterpret_cast<bool *>(value), reinterpret_cast<bool *>(value) + vec_size);
|
||||
prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt32: {
|
||||
case MS_INT32: {
|
||||
std::vector<int32_t> vec_value(reinterpret_cast<int32_t *>(value), reinterpret_cast<int32_t *>(value) + vec_size);
|
||||
prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt64: {
|
||||
case MS_INT64: {
|
||||
std::vector<int64_t> vec_value(reinterpret_cast<int64_t *>(value), reinterpret_cast<int64_t *>(value) + vec_size);
|
||||
prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat32: {
|
||||
case MS_FLOAT32: {
|
||||
std::vector<float> vec_value(reinterpret_cast<float *>(value), reinterpret_cast<float *>(value) + vec_size);
|
||||
prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unrecognized datatype w/ TypeId: " << dataType << " , Attribute name: " << attr_name
|
||||
MS_LOG(ERROR) << "Unrecognized datatype w/ DataTypeC ID: " << data_type << " , Attribute name: " << attr_name
|
||||
<< std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -297,37 +298,36 @@ AttrHandle MSNewAttrBool(ResMgrHandle res_mgr, const bool v) {
|
|||
return GetRawPtr(res_mgr, value);
|
||||
}
|
||||
|
||||
AttrHandle MSOpNewAttrs(ResMgrHandle res_mgr, void *value, size_t vec_size, TypeId data_type) {
|
||||
AttrHandle MSNewAttrArray(ResMgrHandle res_mgr, void *value, size_t vec_size, DataTypeC data_type) {
|
||||
if (res_mgr == nullptr || value == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [value_vec] is nullptr.";
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [value] is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
mindspore::ValuePtr value_node;
|
||||
|
||||
mindspore::ValuePtr value_ptr;
|
||||
switch (data_type) {
|
||||
case TypeId::kNumberTypeBool: {
|
||||
case MS_BOOL: {
|
||||
std::vector<bool> vec_value(reinterpret_cast<bool *>(value), reinterpret_cast<bool *>(value) + vec_size);
|
||||
value_node = mindspore::MakeValue(vec_value);
|
||||
value_ptr = mindspore::MakeValue(vec_value);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt32: {
|
||||
case MS_INT32: {
|
||||
std::vector<int32_t> vec_value(reinterpret_cast<int32_t *>(value), reinterpret_cast<int32_t *>(value) + vec_size);
|
||||
value_node = mindspore::MakeValue(vec_value);
|
||||
value_ptr = mindspore::MakeValue(vec_value);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt64: {
|
||||
case MS_INT64: {
|
||||
std::vector<int64_t> vec_value(reinterpret_cast<int64_t *>(value), reinterpret_cast<int64_t *>(value) + vec_size);
|
||||
value_node = mindspore::MakeValue(vec_value);
|
||||
value_ptr = mindspore::MakeValue(vec_value);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat32: {
|
||||
case MS_FLOAT32: {
|
||||
std::vector<float> vec_value(reinterpret_cast<float *>(value), reinterpret_cast<float *>(value) + vec_size);
|
||||
value_node = mindspore::MakeValue(vec_value);
|
||||
value_ptr = mindspore::MakeValue(vec_value);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unrecognized datatype w/ TypeId: " << data_type << std::endl;
|
||||
MS_LOG(ERROR) << "Unrecognized datatype w/ DataTypeC ID: " << data_type << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
return GetRawPtr(res_mgr, value_node);
|
||||
return GetRawPtr(res_mgr, value_ptr);
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ using PrimitiveImpl = mindspore::Primitive;
|
|||
using TensorImpl = mindspore::tensor::Tensor;
|
||||
using ScalarImpl = mindspore::Scalar;
|
||||
using TypeImpl = mindspore::Type;
|
||||
using TensorTypeImpl = mindspore::TensorType;
|
||||
using AbstractBaseImpl = mindspore::abstract::AbstractBase;
|
||||
using AbstractTensorImpl = mindspore::abstract::AbstractTensor;
|
||||
using AbstractScalarImpl = mindspore::abstract::AbstractScalar;
|
||||
|
@ -51,6 +52,7 @@ using Float32ImmImpl = mindspore::FP32Imm;
|
|||
using BasePtr = mindspore::BasePtr;
|
||||
using ValuePtr = mindspore::ValuePtr;
|
||||
using TypePtr = mindspore::TypePtr;
|
||||
using TensorTypePtr = mindspore::TensorTypePtr;
|
||||
using ScalarPtr = mindspore::ScalarPtr;
|
||||
using Int32ImmPtr = mindspore::Int32ImmPtr;
|
||||
using Int64ImmPtr = mindspore::Int64ImmPtr;
|
||||
|
@ -69,5 +71,9 @@ using FuncGraphPtr = mindspore::FuncGraphPtr;
|
|||
using FuncGraphManagerPtr = std::shared_ptr<mindspore::FuncGraphManager>;
|
||||
|
||||
using AttrMap = mindspore::HashMap<std::string, ValuePtr>;
|
||||
using BaseShapePtr = mindspore::abstract::BaseShapePtr;
|
||||
using Shape = mindspore::abstract::Shape;
|
||||
using ShapePtr = mindspore::abstract::ShapePtr;
|
||||
using TupleShape = mindspore::abstract::TupleShape;
|
||||
|
||||
#endif // MINDSPORE_CCSRC_C_API_SRC_COMMON_H_
|
||||
|
|
|
@ -32,7 +32,6 @@ void MSResourceManagerDestroy(ResMgrHandle res_mgr) {
|
|||
auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
|
||||
delete res_mgr_ptr;
|
||||
res_mgr_ptr = nullptr;
|
||||
return;
|
||||
}
|
||||
|
||||
void MSSetEagerMode(bool eager_mode) {
|
||||
|
@ -40,7 +39,6 @@ void MSSetEagerMode(bool eager_mode) {
|
|||
MS_LOG(WARNING) << "Set Execution mode: " << mode;
|
||||
auto context = mindspore::MsContext::GetInstance();
|
||||
context->set_param<int>(mindspore::MS_CTX_EXECUTION_MODE, mode);
|
||||
return;
|
||||
}
|
||||
|
||||
STATUS MSSetBackendPolicy(const char *policy) {
|
||||
|
@ -53,7 +51,6 @@ void MSSetDeviceTarget(const char *device) {
|
|||
MS_LOG(WARNING) << "Set Device Target: " << device;
|
||||
auto context = mindspore::MsContext::GetInstance();
|
||||
context->set_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET, device);
|
||||
return;
|
||||
}
|
||||
|
||||
STATUS MSGetDeviceTarget(char str_buf[], size_t str_len) {
|
||||
|
@ -75,28 +72,24 @@ void MSSetDeviceId(uint32_t deviceId) {
|
|||
MS_LOG(WARNING) << "Set Device ID: " << deviceId;
|
||||
auto context = mindspore::MsContext::GetInstance();
|
||||
context->set_param<std::uint32_t>(mindspore::MS_CTX_DEVICE_ID, deviceId);
|
||||
return;
|
||||
}
|
||||
|
||||
void MSSetGraphsSaveMode(int save_mode) {
|
||||
MS_LOG(DEBUG) << "Set Graphs Save Mode: " << save_mode;
|
||||
auto context = mindspore::MsContext::GetInstance();
|
||||
context->set_param<int>(mindspore::MS_CTX_SAVE_GRAPHS_FLAG, save_mode);
|
||||
return;
|
||||
}
|
||||
|
||||
void MSSetGraphsSavePath(const char *save_path) {
|
||||
MS_LOG(DEBUG) << "Set Graphs Save Path: " << save_path;
|
||||
auto context = mindspore::MsContext::GetInstance();
|
||||
context->set_param<std::string>(mindspore::MS_CTX_SAVE_GRAPHS_PATH, save_path);
|
||||
return;
|
||||
}
|
||||
|
||||
void MSSetInfer(ResMgrHandle res_mgr, bool infer) {
|
||||
MS_LOG(DEBUG) << "Set Infer Graph: " << infer;
|
||||
auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
|
||||
res_mgr_ptr->SetInfer(infer);
|
||||
return;
|
||||
}
|
||||
|
||||
bool MSGetInfer(ResMgrHandle res_mgr) {
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_C_API_HELPER_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_SRC_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_C_API_SRC_HELPER_H_
|
||||
|
||||
#include <memory>
|
||||
#include "base/base.h"
|
||||
|
@ -36,4 +36,4 @@ T GetSrcPtr(ResMgrHandle res_mgr, ConstHandle raw_ptr) {
|
|||
auto res_ptr = base_ptr->cast<T>();
|
||||
return res_ptr;
|
||||
}
|
||||
#endif // MINDSPORE_CCSRC_C_API_HELPER_H_
|
||||
#endif // MINDSPORE_CCSRC_C_API_SRC_HELPER_H_
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "c_api/include/node.h"
|
||||
#include "c_api/include/attribute.h"
|
||||
#include "c_api/src/helper.h"
|
||||
#include "c_api/src/common.h"
|
||||
#include "c_api/src/utils.h"
|
||||
|
@ -25,10 +26,13 @@
|
|||
#include "ir/scope.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "kernel/oplib/opinfo.h"
|
||||
|
||||
constexpr size_t firstInIdx = 1;
|
||||
constexpr size_t secondInIdx = 2;
|
||||
constexpr size_t switchInputNum = 3;
|
||||
static const size_t maxMallocSize = GetMaxMallocSize();
|
||||
|
||||
STATUS SetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, char **attr_names, AttrHandle attrs[],
|
||||
size_t attr_num) {
|
||||
|
@ -40,7 +44,7 @@ STATUS SetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, char **attr_name
|
|||
}
|
||||
auto value = GetSrcPtr<ValuePtr>(res_mgr, attrs[i]);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Get source pointer failed.";
|
||||
MS_LOG(ERROR) << "Get attribute's source pointer failed, attribute index: " << i;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::string name(attr_names[i]);
|
||||
|
@ -187,6 +191,10 @@ CNodePtr BuildSwitchStructure(ResMgrHandle res_mgr, GraphHandle graph, NodeHandl
|
|||
MS_EXCEPTION_IF_NULL(switch_input);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(input_num == switchInputNum, "Switch's input number must be 3!");
|
||||
NodeHandle switch_op = MSNewOp(res_mgr, graph, "Switch", switch_input, input_num, NULL, NULL, 0);
|
||||
if (switch_op == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Switch op failed!";
|
||||
return nullptr;
|
||||
}
|
||||
auto src_switch = GetSrcPtr<CNodePtr>(res_mgr, switch_op);
|
||||
MS_EXCEPTION_IF_NULL(src_switch);
|
||||
auto fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
|
||||
|
@ -402,6 +410,203 @@ NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, Grap
|
|||
}
|
||||
}
|
||||
|
||||
BaseShapePtr CustomOpInferShape(const CustomOpInfo &info, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto build_shape_func = [](int64_t **out_shapes, size_t *out_dims, size_t out_num) -> BaseShapePtr {
|
||||
BaseShapePtr infer_shape;
|
||||
if (out_num == 1) {
|
||||
int64_t *shape = out_shapes[0];
|
||||
ShapeVector shape_vec(shape, shape + out_dims[0]);
|
||||
infer_shape = std::make_shared<Shape>(shape_vec);
|
||||
} else {
|
||||
std::vector<BaseShapePtr> output_list;
|
||||
for (size_t i = 0; i < out_num; i++) {
|
||||
int64_t *shape = out_shapes[i];
|
||||
ShapeVector shape_vec(shape, shape + out_dims[i]);
|
||||
auto each_shape = std::make_shared<Shape>(shape_vec);
|
||||
output_list.push_back(each_shape);
|
||||
}
|
||||
infer_shape = std::make_shared<TupleShape>(output_list);
|
||||
}
|
||||
return infer_shape;
|
||||
};
|
||||
if (info.output_shapes != nullptr) {
|
||||
if (info.output_dims == nullptr) {
|
||||
MS_LOG(ERROR) << "Output dims must be given if output shapes are specified!";
|
||||
return nullptr;
|
||||
}
|
||||
BaseShapePtr infer_shape = build_shape_func(info.output_shapes, info.output_dims, info.output_num);
|
||||
return infer_shape;
|
||||
} else if (info.shape_infer_func != nullptr) {
|
||||
size_t input_num = info.input_num;
|
||||
size_t output_num = info.output_num;
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num * sizeof(size_t) > maxMallocSize, nullptr,
|
||||
"The input_num is too large for memory allocation.");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(output_num * sizeof(size_t) > maxMallocSize, nullptr,
|
||||
"The output_num is too large for memory allocation.");
|
||||
auto *out_dims_arr = new size_t[output_num];
|
||||
auto **out_shapes_arr = new int64_t *[output_num];
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
out_shapes_arr[i] = new int64_t[MAX_DIMS];
|
||||
}
|
||||
auto *in_dims_arr = new size_t[input_num];
|
||||
auto **in_shapes_arr = new int64_t *[input_num];
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto in_shape = input_args[i]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(in_shape);
|
||||
auto in_shape_ptr = in_shape->cast<ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(in_shape_ptr);
|
||||
auto in_shape_vec = in_shape_ptr->shape();
|
||||
auto in_shape_dim = in_shape_vec.size();
|
||||
in_dims_arr[i] = in_shape_dim;
|
||||
in_shapes_arr[i] = new int64_t[in_shape_dim];
|
||||
for (size_t j = 0; j < in_shape_dim; j++) {
|
||||
in_shapes_arr[i][j] = in_shape_vec[j];
|
||||
}
|
||||
}
|
||||
auto ret = info.shape_infer_func(in_shapes_arr, in_dims_arr, input_num, out_shapes_arr, out_dims_arr, output_num);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Failed to call the shape infer function of custom op!";
|
||||
return nullptr;
|
||||
}
|
||||
BaseShapePtr infer_shape = build_shape_func(out_shapes_arr, out_dims_arr, output_num);
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
delete[] in_shapes_arr[i];
|
||||
}
|
||||
delete[] in_shapes_arr;
|
||||
delete[] in_dims_arr;
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
delete[] out_shapes_arr[i];
|
||||
}
|
||||
delete[] out_shapes_arr;
|
||||
delete[] out_dims_arr;
|
||||
return infer_shape;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Either output shape or output shape infer function must be specified!";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr CustomOpInferType(const CustomOpInfo &info, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto build_type_func = [](const DataTypeC *out_dtypes, size_t out_num) -> TypePtr {
|
||||
TypePtr infer_type;
|
||||
if (out_num == 1) {
|
||||
DataTypeC dtype = out_dtypes[0];
|
||||
auto cxx_type = mindspore::TypeId(dtype);
|
||||
infer_type = mindspore::TypeIdToType(cxx_type);
|
||||
} else {
|
||||
std::vector<TypePtr> type_list;
|
||||
for (size_t i = 0; i < out_num; i++) {
|
||||
DataTypeC dtype = out_dtypes[i];
|
||||
auto cxx_type = mindspore::TypeId(dtype);
|
||||
auto type_val = mindspore::TypeIdToType(cxx_type);
|
||||
type_list.push_back(type_val);
|
||||
}
|
||||
infer_type = std::make_shared<mindspore::Tuple>(type_list);
|
||||
}
|
||||
return infer_type;
|
||||
};
|
||||
if (info.output_dtypes != nullptr) {
|
||||
TypePtr infer_dtype = build_type_func(info.output_dtypes, info.output_num);
|
||||
return infer_dtype;
|
||||
} else if (info.shape_infer_func != nullptr) {
|
||||
size_t input_num = info.input_num;
|
||||
size_t output_num = info.output_num;
|
||||
auto *in_dtypes_arr = new DataTypeC[input_num];
|
||||
auto *out_dtypes_arr = new DataTypeC[output_num];
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto in_type = input_args[i]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(in_type);
|
||||
auto real_type = in_type;
|
||||
if (in_type->isa<TensorTypeImpl>()) {
|
||||
auto tensor_type = in_type->cast<TensorTypePtr>();
|
||||
real_type = tensor_type->element();
|
||||
}
|
||||
auto in_type_id = (enum DataTypeC)(real_type->type_id());
|
||||
in_dtypes_arr[i] = in_type_id;
|
||||
}
|
||||
STATUS ret = info.dtype_infer_func(in_dtypes_arr, input_num, out_dtypes_arr, output_num);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Failed to call the dtype infer function of custom op!";
|
||||
return nullptr;
|
||||
}
|
||||
TypePtr infer_dtype = build_type_func(out_dtypes_arr, output_num);
|
||||
delete[] in_dtypes_arr;
|
||||
delete[] out_dtypes_arr;
|
||||
return infer_dtype;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Either output dtype or output dtype infer function must be specified!";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
NodeHandle MSNewCustomOp(ResMgrHandle res_mgr, GraphHandle graph, Handle const inputs[], size_t input_num,
|
||||
CustomOpInfo info) {
|
||||
if (res_mgr == nullptr || graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num != info.input_num, nullptr,
|
||||
"Input node number is not matched with the input number specified in custom op info.");
|
||||
auto ret = CheckCustomOpInfo(info);
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Invalid custom op info.");
|
||||
try {
|
||||
auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
|
||||
auto org_infer = res_mgr_ptr->GetInfer();
|
||||
res_mgr_ptr->SetInfer(false);
|
||||
NodeHandle custom_op =
|
||||
MSNewOp(res_mgr, graph, "Custom", inputs, info.input_num, info.attr_name, info.attr_value, info.attr_num);
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(custom_op == nullptr, nullptr, "Create Custom op failed!");
|
||||
res_mgr_ptr->SetInfer(org_infer);
|
||||
// Supplement necessary attributes
|
||||
ret = MSOpSetAttrString(res_mgr, custom_op, mindspore::kAttrFuncType, info.func_type);
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Custom op set func type attribute failed.");
|
||||
ret = MSOpSetAttrString(res_mgr, custom_op, mindspore::kAttrFuncName, info.func_name);
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Custom op set func name attribute failed.");
|
||||
// Build json object
|
||||
nlohmann::json json_obj = ConvertOpInfoToJson(info);
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(json_obj.empty(), nullptr, "Failed to convert op info to json.");
|
||||
// Create op info and set info map
|
||||
auto op_name = json_obj.at(mindspore::kernel::kOpName).get<std::string>();
|
||||
auto imply_type = json_obj.at(mindspore::kernel::kImplyType).get<std::string>();
|
||||
std::string func_name = info.func_name;
|
||||
std::string target_name = info.target;
|
||||
auto iter = mindspore::kernel::kImplyTypeStrToEnumMap.find(imply_type);
|
||||
if (iter == mindspore::kernel::kImplyTypeStrToEnumMap.end()) {
|
||||
MS_LOG(ERROR) << "Not support imply_type: " << imply_type;
|
||||
return nullptr;
|
||||
}
|
||||
auto op_info = mindspore::kernel::OpLib::DecodeOpInfo(json_obj, iter->second, "");
|
||||
if (op_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Decode op info failed: func_name: " << func_name << " imply_type " << imply_type;
|
||||
return nullptr;
|
||||
}
|
||||
op_info->set_processor(imply_type);
|
||||
auto key = op_name + imply_type;
|
||||
auto &op_infos = mindspore::kernel::OpLib::GetOpInfoMap();
|
||||
(void)op_infos[iter->second].insert(std::pair<std::string, mindspore::kernel::OpInfoPtr>(key, op_info));
|
||||
// Infer shape and type
|
||||
mindspore::AbstractBasePtrList abs_list{};
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto in_node = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
|
||||
MS_EXCEPTION_IF_NULL(in_node);
|
||||
abs_list.push_back(in_node->abstract());
|
||||
}
|
||||
auto infer_shape = CustomOpInferShape(info, abs_list);
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(infer_shape == nullptr, nullptr, "Custom op infer shape failed!");
|
||||
auto infer_type = CustomOpInferType(info, abs_list);
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(infer_type == nullptr, nullptr, "Custom op infer type failed!");
|
||||
AbstractBasePtr custom_abs = mindspore::abstract::MakeAbstract(infer_shape, infer_type);
|
||||
MS_EXCEPTION_IF_NULL(custom_abs);
|
||||
auto src_op = GetSrcPtr<CNodePtr>(res_mgr, custom_op);
|
||||
MS_EXCEPTION_IF_NULL(src_op);
|
||||
src_op->set_abstract(custom_abs);
|
||||
return custom_op;
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Get custom op failed. Error info: " << e.what();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
NodeHandle MSOpGetInput(ResMgrHandle res_mgr, ConstNodeHandle op, size_t i) {
|
||||
if (res_mgr == nullptr || op == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
|
||||
|
@ -502,7 +707,7 @@ NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph, ConstGraph
|
|||
return GetRawPtr(res_mgr, cnode);
|
||||
}
|
||||
|
||||
NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph, TypeId type, const int64_t shape[],
|
||||
NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph, DataTypeC type, const int64_t shape[],
|
||||
size_t shape_size) {
|
||||
if (res_mgr == nullptr || graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
|
||||
|
@ -523,8 +728,8 @@ NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph, TypeId type
|
|||
return GetRawPtr(res_mgr, param);
|
||||
}
|
||||
|
||||
NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle graph, void *data, TypeId type, const int64_t shape[],
|
||||
size_t shape_size, size_t data_len) {
|
||||
NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle graph, void *data, DataTypeC type,
|
||||
const int64_t shape[], size_t shape_size, size_t data_len) {
|
||||
if (res_mgr == nullptr || graph == nullptr || data == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [data] or [shape] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -614,8 +819,8 @@ void *MSTensorVariableGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
|
|||
}
|
||||
}
|
||||
|
||||
NodeHandle MSNewTensorConstant(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[], size_t shape_size,
|
||||
size_t data_len) {
|
||||
NodeHandle MSNewTensorConstant(ResMgrHandle res_mgr, void *data, DataTypeC type, const int64_t shape[],
|
||||
size_t shape_size, size_t data_len) {
|
||||
if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [data] or [shape] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -772,7 +977,7 @@ NodeHandle MSNewTupleConstantInt64(ResMgrHandle res_mgr, const int64_t vec[], si
|
|||
return GetRawPtr(res_mgr, value_node);
|
||||
}
|
||||
|
||||
NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, TypeId type) {
|
||||
NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, DataTypeC type) {
|
||||
MS_LOG(INFO) << "New Type Value: " << type;
|
||||
if (res_mgr == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
|
||||
|
@ -1013,16 +1218,16 @@ STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node,
|
|||
}
|
||||
}
|
||||
|
||||
TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
DataTypeC MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
MS_LOG(INFO) << "Get Type Constant Value!";
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return (enum TypeId)0;
|
||||
return MS_INVALID_TYPE;
|
||||
}
|
||||
if (res_mgr == nullptr || node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
|
||||
*error = RET_NULL_PTR;
|
||||
return (enum TypeId)0;
|
||||
return MS_INVALID_TYPE;
|
||||
}
|
||||
try {
|
||||
auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
|
||||
|
@ -1030,13 +1235,13 @@ TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS
|
|||
auto val = node_impl->value();
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
auto val_type = val->cast<TypePtr>();
|
||||
auto ret_val = static_cast<TypeId>(val_type->type_id());
|
||||
auto ret_val = static_cast<DataTypeC>(val_type->type_id());
|
||||
*error = RET_OK;
|
||||
return ret_val;
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Get Type Constant value failed. Error info: " << e.what();
|
||||
*error = RET_ERROR;
|
||||
return (enum TypeId)0;
|
||||
return MS_INVALID_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -36,10 +36,9 @@ void GetDataByFile(std::vector<T> data, const char *path, size_t *elem_size) {
|
|||
}
|
||||
fin.close();
|
||||
*elem_size = data.size() * sizeof(T);
|
||||
return;
|
||||
}
|
||||
|
||||
TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[], size_t shape_size,
|
||||
TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, DataTypeC type, const int64_t shape[], size_t shape_size,
|
||||
size_t data_len) {
|
||||
if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [data] or [shape] is nullptr.";
|
||||
|
@ -56,7 +55,7 @@ TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const in
|
|||
return GetRawPtr(res_mgr, tensor);
|
||||
}
|
||||
|
||||
TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_t shape[], size_t shape_size,
|
||||
TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, DataTypeC type, const int64_t shape[], size_t shape_size,
|
||||
const char *path) {
|
||||
if (res_mgr == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [shape] is nullptr.";
|
||||
|
@ -67,26 +66,26 @@ TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_
|
|||
try {
|
||||
size_t data_len;
|
||||
switch (type) {
|
||||
case TypeId::kNumberTypeInt32: {
|
||||
case MS_INT32: {
|
||||
std::vector<int32_t> data;
|
||||
(void)GetDataByFile<int32_t>(data, path, &data_len);
|
||||
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt64: {
|
||||
case MS_INT64: {
|
||||
std::vector<int64_t> data;
|
||||
(void)GetDataByFile<int64_t>(data, path, &data_len);
|
||||
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat32: {
|
||||
case MS_FLOAT32: {
|
||||
std::vector<float> data;
|
||||
(void)GetDataByFile<float>(data, path, &data_len);
|
||||
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unrecognized datatype w/ TypeId: " << type << std::endl;
|
||||
MS_LOG(ERROR) << "Unrecognized datatype w/ DataTypeC ID: " << type << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
|
@ -97,7 +96,7 @@ TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_
|
|||
}
|
||||
|
||||
TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data, const int64_t shape[], size_t shape_size,
|
||||
TypeId tensor_type, TypeId src_type) {
|
||||
DataTypeC tensor_type, DataTypeC src_type) {
|
||||
if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [data] or [shape] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -126,7 +125,7 @@ void *MSTensorGetData(ResMgrHandle res_mgr, ConstTensorHandle tensor) {
|
|||
return src_tensor->data_c();
|
||||
}
|
||||
|
||||
STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, TypeId type) {
|
||||
STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, DataTypeC type) {
|
||||
if (res_mgr == nullptr || tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
|
||||
return RET_ERROR;
|
||||
|
@ -140,24 +139,24 @@ STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, TypeId typ
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
TypeId MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
|
||||
DataTypeC MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return (enum TypeId)0;
|
||||
return MS_INVALID_TYPE;
|
||||
}
|
||||
if (res_mgr == nullptr || tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
|
||||
*error = RET_NULL_PTR;
|
||||
return (enum TypeId)0;
|
||||
return MS_INVALID_TYPE;
|
||||
}
|
||||
auto src_tensor = GetSrcPtr<TensorPtr>(res_mgr, tensor);
|
||||
if (src_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Get source pointer failed.";
|
||||
*error = RET_NULL_PTR;
|
||||
return (enum TypeId)0;
|
||||
return MS_INVALID_TYPE;
|
||||
}
|
||||
*error = RET_OK;
|
||||
return (enum TypeId)(src_tensor->data_type_c());
|
||||
return (enum DataTypeC)(src_tensor->data_type_c());
|
||||
}
|
||||
|
||||
size_t MSTensorGetDataSize(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
|
||||
|
|
|
@ -80,3 +80,103 @@ AbstractBasePtr GetAbstract(const TypePtr &type_ptr, const int64_t shape[], size
|
|||
ShapeVector shape_vec(shape, shape + shape_size);
|
||||
return std::make_shared<AbstractTensorImpl>(type_ptr, shape_vec);
|
||||
}
|
||||
|
||||
STATUS CheckCustomOpInfo(const CustomOpInfo &info) {
|
||||
MS_ERROR_IF_FALSE_W_RET_N_LOG(info.func_name != nullptr, RET_ERROR, "The func_name of custom op must be specified!");
|
||||
MS_ERROR_IF_FALSE_W_RET_N_LOG(info.func_type != nullptr, RET_ERROR, "The func_type of custom op must be specified!");
|
||||
MS_ERROR_IF_FALSE_W_RET_N_LOG(info.target != nullptr, RET_ERROR, "The target of custom op must be specified!");
|
||||
MS_ERROR_IF_FALSE_W_RET_N_LOG(info.input_name != nullptr, RET_ERROR,
|
||||
"The input_name of custom op must be specified!");
|
||||
MS_ERROR_IF_FALSE_W_RET_N_LOG(info.output_name != nullptr, RET_ERROR,
|
||||
"The output_name of custom op must be specified!");
|
||||
MS_ERROR_IF_FALSE_W_RET_N_LOG(info.input_num > 0, RET_ERROR, "The input_num of custom op must be a positive value!");
|
||||
MS_ERROR_IF_FALSE_W_RET_N_LOG(info.output_num > 0, RET_ERROR,
|
||||
"The output_num of custom op must be a positive value!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.dtype_infer_func == nullptr && info.output_dtypes == nullptr, RET_ERROR,
|
||||
"Either dtype infer function or output shape must be specified!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.dtype_infer_func != nullptr && info.output_dtypes != nullptr, RET_ERROR,
|
||||
"Only one should be specified between dtype infer function and output shape!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.shape_infer_func == nullptr && info.output_shapes == nullptr, RET_ERROR,
|
||||
"Either shape infer function or output shape must be specified!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.shape_infer_func != nullptr && info.output_shapes != nullptr, RET_ERROR,
|
||||
"Only one should be specified between shape infer function and output shape!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.output_shapes != nullptr && info.output_dims == nullptr, RET_ERROR,
|
||||
"Output dims must be specified if output_shapes are given!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.attr_name != nullptr && info.attr_num == 0, RET_ERROR,
|
||||
"The attr_num of custom op must be none-zero if attr_name is specified!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.attr_name == nullptr && info.attr_num != 0, RET_ERROR,
|
||||
"The attr_num of custom op must be zero if attr_name is not specified!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.dtype_formats != nullptr && info.dtype_formats_num == 0, RET_ERROR,
|
||||
"The dtype_formats_num of custom op must be none-zero if dtype_formats is specified!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(info.dtype_formats == nullptr && info.dtype_formats_num != 0, RET_ERROR,
|
||||
"The dtype_formats_num of custom op must be zero if dtype_formats is not specified!");
|
||||
MS_ERROR_IF_TRUE_W_RET_N_LOG(std::string(info.func_name).find(".so:") == std::string::npos, RET_ERROR,
|
||||
"so file path and function name must be provided in func_name!");
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
nlohmann::json ConvertOpInfoToJson(const CustomOpInfo &info) {
|
||||
nlohmann::json obj;
|
||||
obj["attr"] = {};
|
||||
std::string target = info.target;
|
||||
obj["target"] = target;
|
||||
obj["op_name"] = "Custom" + std::string(info.func_name);
|
||||
obj["fusion_tyoe"] = "OPAQUE";
|
||||
if (info.dtype_formats != nullptr) {
|
||||
std::vector<std::vector<std::string>> dtype_formats;
|
||||
for (size_t i = 0; i < info.dtype_formats_num; i++) {
|
||||
for (size_t j = 0; j < info.input_num + info.output_num; j++) {
|
||||
auto iter = kDTypeFmtEnumToStrMap.find(info.dtype_formats[i][j]);
|
||||
if (iter == kDTypeFmtEnumToStrMap.end()) {
|
||||
MS_LOG(ERROR) << "Unsupported DTypeFormat: " << info.dtype_formats[i][j];
|
||||
return {};
|
||||
}
|
||||
dtype_formats.push_back(iter->second);
|
||||
}
|
||||
}
|
||||
obj["dtype_format"] = {dtype_formats};
|
||||
}
|
||||
std::vector<nlohmann::json> js_inputs;
|
||||
for (size_t i = 0; i < info.input_num; i++) {
|
||||
nlohmann::json js_input;
|
||||
js_input["index"] = i;
|
||||
js_input["name"] = std::string(info.input_name[i]);
|
||||
js_input["paramType"] = "required";
|
||||
js_inputs.push_back(js_input);
|
||||
}
|
||||
obj["inputs"] = js_inputs;
|
||||
std::vector<nlohmann::json> js_outputs;
|
||||
for (size_t i = 0; i < info.output_num; i++) {
|
||||
nlohmann::json js_output;
|
||||
js_output["index"] = i;
|
||||
js_output["name"] = std::string(info.output_name[i]);
|
||||
js_output["paramType"] = "required";
|
||||
js_outputs.push_back(js_output);
|
||||
}
|
||||
obj["outputs"] = js_outputs;
|
||||
auto aot_imply_type = target == "Ascend" ? "BiSheng" : target;
|
||||
const std::map<std::string, std::string> func_type_to_imply_type = {
|
||||
{"hybrid", "AKG"}, {"akg", "AKG"}, {"tbe", "TBE"}, {"aicpu", "AICPU"},
|
||||
{"pyfunc", target}, {"julia", target}, {"aot", aot_imply_type}};
|
||||
auto iter = func_type_to_imply_type.find(std::string(info.func_type));
|
||||
if (iter == func_type_to_imply_type.end()) {
|
||||
MS_LOG(ERROR) << "Unsupported function type: " << std::string(info.func_type);
|
||||
return {};
|
||||
}
|
||||
auto imply_type = iter->second;
|
||||
obj["imply_type"] = imply_type;
|
||||
return obj;
|
||||
}
|
||||
|
||||
size_t GetMaxMallocSize() {
|
||||
size_t max_malloc_size = 0;
|
||||
#if defined(_MSC_VER) || defined(_WIN32)
|
||||
MEMORYSTATUSEX status;
|
||||
status.dwLength = sizeof(status);
|
||||
GlobalMemoryStatusEx(&status);
|
||||
max_malloc_size = static_cast<size_t>(status.ullTotalPhys);
|
||||
#else
|
||||
max_malloc_size = static_cast<size_t>(sysconf(_SC_PHYS_PAGES)) * static_cast<size_t>(sysconf(_SC_PAGESIZE));
|
||||
#endif
|
||||
return max_malloc_size;
|
||||
}
|
||||
|
|
|
@ -14,20 +14,232 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_C_API_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_C_API_UTILS_H_
|
||||
#ifndef MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "base/base.h"
|
||||
#include "base/base_ref.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "c_api/src/resource_manager.h"
|
||||
#include "c_api/include/context.h"
|
||||
#include "c_api/include/node.h"
|
||||
#include "c_api/src/common.h"
|
||||
|
||||
const std::map<DTypeFormat, std::vector<std::string>> kDTypeFmtEnumToStrMap = {
|
||||
{None_None, {"", ""}},
|
||||
{None_Default, {"", "DefaultFormat"}},
|
||||
{BOOL_None, {"bool", ""}},
|
||||
{BOOL_Default, {"bool", "DefaultFormat"}},
|
||||
{BOOL_5HD, {"bool", "NC1HWC0"}},
|
||||
{BOOL_FracZ, {"bool", "FRACTAL_Z"}},
|
||||
{BOOL_FracNZ, {"bool", "FRACTAL_NZ"}},
|
||||
{BOOL_C1HWNCoC0, {"bool", "C1HWNCoC0"}},
|
||||
{BOOL_NCHW, {"bool", "NCHW"}},
|
||||
{BOOL_NHWC, {"bool", "NHWC"}},
|
||||
{BOOL_HWCN, {"bool", "HWCN"}},
|
||||
{BOOL_NDHWC, {"bool", "NDHWC"}},
|
||||
{BOOL_ChannelLast, {"bool", "ChannelLast"}},
|
||||
{BOOL_Default_Tuple, {"bool", "DefaultFormat", "tuple"}},
|
||||
{BOOL_Default_List, {"bool", "DefaultFormat", "list"}},
|
||||
{I8_None, {"int8", ""}},
|
||||
{I8_Default, {"int8", "DefaultFormat"}},
|
||||
{I8_5HD, {"int8", "NC1HWC0"}},
|
||||
{I8_FracZ, {"int8", "FRACTAL_Z"}},
|
||||
{I8_FracNZ, {"int8", "FRACTAL_NZ"}},
|
||||
{I8_C1HWNCoC0, {"int8", "C1HWNCoC0"}},
|
||||
{I8_NCHW, {"int8", "NCHW"}},
|
||||
{I8_NHWC, {"int8", "NHWC"}},
|
||||
{I8_HWCN, {"int8", "HWCN"}},
|
||||
{I8_NDHWC, {"int8", "NDHWC"}},
|
||||
{I8_NCDHW, {"int8", "NCDHW"}},
|
||||
{I8_ChannelLast, {"int8", "ChannelLast"}},
|
||||
{I8_NDC1HWC0, {"int8", "NDC1HWC0"}},
|
||||
{I8_NC1HWC0, {"int8", "NC1HWC0"}},
|
||||
{I8_Default_Tuple, {"int8", "DefaultFormat", "tuple"}},
|
||||
{I8_Default_List, {"int8", "DefaultFormat", "list"}},
|
||||
{U8_None, {"uint8", ""}},
|
||||
{U8_Default, {"uint8", "DefaultFormat"}},
|
||||
{U8_5HD, {"uint8", "NC1HWC0"}},
|
||||
{U8_FracZ, {"uint8", "FRACTAL_Z"}},
|
||||
{U8_FracNZ, {"uint8", "FRACTAL_NZ"}},
|
||||
{U8_C1HWNCoC0, {"uint8", "C1HWNCoC0"}},
|
||||
{U8_NCHW, {"uint8", "NCHW"}},
|
||||
{U8_NHWC, {"uint8", "NHWC"}},
|
||||
{U8_HWCN, {"uint8", "HWCN"}},
|
||||
{U8_NDHWC, {"uint8", "NDHWC"}},
|
||||
{U8_NCDHW, {"uint8", "NCDHW"}},
|
||||
{U8_ChannelLast, {"uint8", "ChannelLast"}},
|
||||
{U8_NDC1HWC0, {"uint8", "NDC1HWC0"}},
|
||||
{U8_NC1HWC0, {"uint8", "NC1HWC0"}},
|
||||
{U8_Default_Tuple, {"uint8", "DefaultFormat", "tuple"}},
|
||||
{U8_Default_List, {"uint8", "DefaultFormat", "list"}},
|
||||
{I16_None, {"int16", ""}},
|
||||
{I16_Default, {"int16", "DefaultFormat"}},
|
||||
{I16_5HD, {"int16", "NC1HWC0"}},
|
||||
{I16_FracZ, {"int16", "FRACTAL_Z"}},
|
||||
{I16_FracNZ, {"int16", "FRACTAL_NZ"}},
|
||||
{I16_C1HWNCoC0, {"int16", "C1HWNCoC0"}},
|
||||
{I16_NCHW, {"int16", "NCHW"}},
|
||||
{I16_NHWC, {"int16", "NHWC"}},
|
||||
{I16_HWCN, {"int16", "HWCN"}},
|
||||
{I16_NDHWC, {"int16", "NDHWC"}},
|
||||
{I16_ChannelLast, {"int16", "ChannelLast"}},
|
||||
{I16_Default_Tuple, {"int16", "DefaultFormat", "tuple"}},
|
||||
{I16_Default_List, {"int16", "DefaultFormat", "list"}},
|
||||
{U16_None, {"uint16", ""}},
|
||||
{U16_Default, {"uint16", "DefaultFormat"}},
|
||||
{U16_5HD, {"uint16", "NC1HWC0"}},
|
||||
{U16_FracZ, {"uint16", "FRACTAL_Z"}},
|
||||
{U16_FracNZ, {"uint16", "FRACTAL_NZ"}},
|
||||
{U16_C1HWNCoC0, {"uint16", "C1HWNCoC0"}},
|
||||
{U16_NCHW, {"uint16", "NCHW"}},
|
||||
{U16_NHWC, {"uint16", "NHWC"}},
|
||||
{U16_HWCN, {"uint16", "HWCN"}},
|
||||
{U16_NDHWC, {"uint16", "NDHWC"}},
|
||||
{U16_ChannelLast, {"uint16", "ChannelLast"}},
|
||||
{U16_Default_Tuple, {"uint16", "DefaultFormat", "tuple"}},
|
||||
{U16_Default_List, {"uint16", "DefaultFormat", "list"}},
|
||||
{I32_None, {"int32", ""}},
|
||||
{I32_Default, {"int32", "DefaultFormat"}},
|
||||
{I32_5HD, {"int32", "NC1HWC0"}},
|
||||
{I32_FracZ, {"int32", "FRACTAL_Z"}},
|
||||
{I32_FracNZ, {"int32", "FRACTAL_NZ"}},
|
||||
{I32_C1HWNCoC0, {"int32", "C1HWNCoC0"}},
|
||||
{I32_NCHW, {"int32", "NCHW"}},
|
||||
{I32_NHWC, {"int32", "NHWC"}},
|
||||
{I32_HWCN, {"int32", "HWCN"}},
|
||||
{I32_NDHWC, {"int32", "NDHWC"}},
|
||||
{I32_NDC1HWC0, {"int32", "NDC1HWC0"}},
|
||||
{I32_NCDHW, {"int32", "NCDHW"}},
|
||||
{I32_ChannelLast, {"int32", "ChannelLast"}},
|
||||
{I32_Default_Tuple, {"int32", "DefaultFormat", "tuple"}},
|
||||
{I32_Default_List, {"int32", "DefaultFormat", "list"}},
|
||||
{U32_None, {"uint32", ""}},
|
||||
{U32_Default, {"uint32", "DefaultFormat"}},
|
||||
{U32_5HD, {"uint32", "NC1HWC0"}},
|
||||
{U32_FracZ, {"uint32", "FRACTAL_Z"}},
|
||||
{U32_FracNZ, {"uint32", "FRACTAL_NZ"}},
|
||||
{U32_C1HWNCoC0, {"uint32", "C1HWNCoC0"}},
|
||||
{U32_NCHW, {"uint32", "NCHW"}},
|
||||
{U32_NHWC, {"uint32", "NHWC"}},
|
||||
{U32_HWCN, {"uint32", "HWCN"}},
|
||||
{U32_NDHWC, {"uint32", "NDHWC"}},
|
||||
{U32_ChannelLast, {"uint32", "ChannelLast"}},
|
||||
{U32_Default_Tuple, {"uint32", "DefaultFormat", "tuple"}},
|
||||
{U32_Default_List, {"uint32", "DefaultFormat", "list"}},
|
||||
{I64_None, {"int64", ""}},
|
||||
{I64_Default, {"int64", "DefaultFormat"}},
|
||||
{I64_5HD, {"int64", "NC1HWC0"}},
|
||||
{I64_FracZ, {"int64", "FRACTAL_Z"}},
|
||||
{I64_FracNZ, {"int64", "FRACTAL_NZ"}},
|
||||
{I64_C1HWNCoC0, {"int64", "C1HWNCoC0"}},
|
||||
{I64_NCHW, {"int64", "NCHW"}},
|
||||
{I64_NHWC, {"int64", "NHWC"}},
|
||||
{I64_HWCN, {"int64", "HWCN"}},
|
||||
{I64_NDHWC, {"int64", "NDHWC"}},
|
||||
{I64_ChannelLast, {"int64", "ChannelLast"}},
|
||||
{I64_Default_Tuple, {"int64", "DefaultFormat", "tuple"}},
|
||||
{I64_Default_List, {"int64", "DefaultFormat", "list"}},
|
||||
{U64_None, {"uint64", ""}},
|
||||
{U64_Default, {"uint64", "DefaultFormat"}},
|
||||
{U64_5HD, {"uint64", "NC1HWC0"}},
|
||||
{U64_FracZ, {"uint64", "FRACTAL_Z"}},
|
||||
{U64_FracNZ, {"uint64", "FRACTAL_NZ"}},
|
||||
{U64_C1HWNCoC0, {"uint64", "C1HWNCoC0"}},
|
||||
{U64_NCHW, {"uint64", "NCHW"}},
|
||||
{U64_NHWC, {"uint64", "NHWC"}},
|
||||
{U64_HWCN, {"uint64", "HWCN"}},
|
||||
{U64_NDHWC, {"uint64", "NDHWC"}},
|
||||
{U64_ChannelLast, {"uint64", "ChannelLast"}},
|
||||
{U64_Default_Tuple, {"uint64", "DefaultFormat", "tuple"}},
|
||||
{U64_Default_List, {"uint64", "DefaultFormat", "list"}},
|
||||
{F16_None, {"float16", ""}},
|
||||
{F16_Default, {"float16", "DefaultFormat"}},
|
||||
{F16_5HD, {"float16", "NC1HWC0"}},
|
||||
{F16_FracZ, {"float16", "FRACTAL_Z"}},
|
||||
{F16_FracNZ, {"float16", "FRACTAL_NZ"}},
|
||||
{F16_C1HWNCoC0, {"float16", "C1HWNCoC0"}},
|
||||
{F16_NCHW, {"float16", "NCHW"}},
|
||||
{F16_NHWC, {"float16", "NHWC"}},
|
||||
{F16_HWCN, {"float16", "HWCN"}},
|
||||
{F16_NDHWC, {"float16", "NDHWC"}},
|
||||
{F16_NCDHW, {"float16", "NCDHW"}},
|
||||
{F16_DHWCN, {"float16", "DHWCN"}},
|
||||
{F16_NDC1HWC0, {"float16", "NDC1HWC0"}},
|
||||
{F16_FRACTAL_Z_3D, {"float16", "FRACTAL_Z_3D"}},
|
||||
{F16_FracZNLSTM, {"float16", "FRACTAL_ZN_LSTM"}},
|
||||
{F16_FracZNRNN, {"float16", "FRACTAL_ZN_RNN"}},
|
||||
{F16_ND_RNNBIAS, {"float16", "ND_RNN_BIAS"}},
|
||||
{F16_ChannelLast, {"float16", "ChannelLast"}},
|
||||
{F16_Default_Tuple, {"float16", "DefaultFormat", "tuple"}},
|
||||
{F16_Default_List, {"float16", "DefaultFormat", "list"}},
|
||||
{F32_None, {"float32", ""}},
|
||||
{F32_Default, {"float32", "DefaultFormat"}},
|
||||
{F32_5HD, {"float32", "NC1HWC0"}},
|
||||
{F32_FracZ, {"float32", "FRACTAL_Z"}},
|
||||
{F32_FracNZ, {"float32", "FRACTAL_NZ"}},
|
||||
{F32_C1HWNCoC0, {"float32", "C1HWNCoC0"}},
|
||||
{F32_NCHW, {"float32", "NCHW"}},
|
||||
{F32_NHWC, {"float32", "NHWC"}},
|
||||
{F32_HWCN, {"float32", "HWCN"}},
|
||||
{F32_NDHWC, {"float32", "NDHWC"}},
|
||||
{F32_NCDHW, {"float32", "NCDHW"}},
|
||||
{F32_DHWCN, {"float32", "DHWCN"}},
|
||||
{F32_NDC1HWC0, {"float32", "NDC1HWC0"}},
|
||||
{F32_FRACTAL_Z_3D, {"float32", "FRACTAL_Z_3D"}},
|
||||
{F32_FracZNLSTM, {"float32", "FRACTAL_ZN_LSTM"}},
|
||||
{F32_FracZNRNN, {"float32", "FRACTAL_ZN_RNN"}},
|
||||
{F32_ND_RNNBIAS, {"float32", "ND_RNN_BIAS"}},
|
||||
{F32_ChannelLast, {"float32", "ChannelLast"}},
|
||||
{F32_Default_Tuple, {"float32", "DefaultFormat", "tuple"}},
|
||||
{F32_Default_List, {"float32", "DefaultFormat", "list"}},
|
||||
{F64_None, {"float64", ""}},
|
||||
{F64_Default, {"float64", "DefaultFormat"}},
|
||||
{F64_5HD, {"float64", "NC1HWC0"}},
|
||||
{F64_FracZ, {"float64", "FRACTAL_Z"}},
|
||||
{F64_FracNZ, {"float64", "FRACTAL_NZ"}},
|
||||
{F64_C1HWNCoC0, {"float64", "C1HWNCoC0"}},
|
||||
{F64_NCHW, {"float64", "NCHW"}},
|
||||
{F64_NHWC, {"float64", "NHWC"}},
|
||||
{F64_HWCN, {"float64", "HWCN"}},
|
||||
{F64_NDHWC, {"float64", "NDHWC"}},
|
||||
{F64_ChannelLast, {"float64", "ChannelLast"}},
|
||||
{F64_Default_Tuple, {"float64", "DefaultFormat", "tuple"}},
|
||||
{F64_Default_List, {"float64", "DefaultFormat", "list"}},
|
||||
{C64_Default, {"complex64", "DefaultFormat"}},
|
||||
{C128_Default, {"complex128", "DefaultFormat"}},
|
||||
};
|
||||
|
||||
void ConvertConstScalarInputToTensor(const AnfNodePtr &input_node);
|
||||
|
||||
std::vector<TensorPtr> ConvertOutputToTensor(const mindspore::BaseRef &output);
|
||||
|
||||
AbstractBasePtr GetAbstract(const TypePtr &type, const int64_t shape[], size_t shape_size, bool is_param = false);
|
||||
#endif // MINDSPORE_CCSRC_C_API_UTILS_H_
|
||||
|
||||
STATUS CheckCustomOpInfo(const CustomOpInfo &info);
|
||||
|
||||
nlohmann::json ConvertOpInfoToJson(const CustomOpInfo &info);
|
||||
|
||||
size_t GetMaxMallocSize();
|
||||
|
||||
#define MS_ERROR_IF_FALSE_W_RET_N_LOG(condition, val, message) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
MS_LOG(ERROR) << message; \
|
||||
return val; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define MS_ERROR_IF_TRUE_W_RET_N_LOG(condition, val, message) \
|
||||
do { \
|
||||
if ((condition)) { \
|
||||
MS_LOG(ERROR) << message; \
|
||||
return val; \
|
||||
} \
|
||||
} while (0)
|
||||
#endif // MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
|
||||
|
|
|
@ -1071,6 +1071,7 @@ constexpr auto kAttrFixedOutputFormat = "fixed_output_format";
|
|||
constexpr auto kAttrFixedInputDeviceShape = "fixed_input_device_shape";
|
||||
constexpr auto kAttrFixedOutputDeviceShape = "fixed_output_device_shape";
|
||||
constexpr auto kAttrFuncType = "func_type";
|
||||
constexpr auto kAttrFuncName = "func_name";
|
||||
constexpr auto kNonMaxSuppressionWithOverlapsOpName = "NonMaxSuppressionWithOverlaps";
|
||||
constexpr auto kAttrCustAicpu = "cust_aicpu";
|
||||
constexpr auto kAttrIsInternalOutputNopNode = "is_internal_output_nop_node";
|
||||
|
|
|
@ -36,11 +36,11 @@ class BACKEND_EXPORT OpLib {
|
|||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type,
|
||||
bool is_dynamic_shape = false);
|
||||
static std::map<mindspore::kernel::OpImplyType, std::map<std::string, std::shared_ptr<OpInfo>>> &GetOpInfoMap();
|
||||
static std::shared_ptr<OpInfo> DecodeOpInfo(const nlohmann::json &obj, const OpImplyType &imply_type,
|
||||
const std::string &impl_path);
|
||||
|
||||
private:
|
||||
static bool RegOpFromLocalInfo();
|
||||
static std::shared_ptr<OpInfo> DecodeOpInfo(const nlohmann::json &obj, const OpImplyType &imply_type,
|
||||
const std::string &impl_path);
|
||||
static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType &imply_type,
|
||||
const std::shared_ptr<OpInfo> &op_info);
|
||||
static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
|
||||
|
|
|
@ -44,7 +44,7 @@ TEST_F(TestCApiAttr, test_attr) {
|
|||
AttrHandle attr1 = MSNewAttrInt64(res_mgr, 1);
|
||||
ASSERT_TRUE(attr1 != nullptr);
|
||||
int64_t attr2_raw[] = {2, 2};
|
||||
AttrHandle attr2 = MSOpNewAttrs(res_mgr, attr2_raw, 2, kNumberTypeInt64);
|
||||
AttrHandle attr2 = MSNewAttrArray(res_mgr, attr2_raw, 2, MS_INT64);
|
||||
ASSERT_TRUE(attr2 != nullptr);
|
||||
char name1[] = "attr1";
|
||||
char name2[] = "attr2";
|
||||
|
@ -52,7 +52,7 @@ TEST_F(TestCApiAttr, test_attr) {
|
|||
AttrHandle attrs[] = {attr1, attr2};
|
||||
size_t attr_num = 2;
|
||||
|
||||
NodeHandle x = MSNewPlaceholder(res_mgr, fg, kNumberTypeInt32, NULL, 0);
|
||||
NodeHandle x = MSNewPlaceholder(res_mgr, fg, MS_INT32, NULL, 0);
|
||||
ASSERT_TRUE(x != nullptr);
|
||||
NodeHandle y = MSNewScalarConstantInt32(res_mgr, 2);
|
||||
ASSERT_TRUE(y != nullptr);
|
||||
|
@ -76,7 +76,7 @@ TEST_F(TestCApiAttr, test_attr) {
|
|||
ASSERT_EQ(attr1_retrived, 2);
|
||||
values[0] = 1;
|
||||
values[1] = 1;
|
||||
ret = MSOpSetAttrArray(res_mgr, op_add, "attr2", values, 2, kNumberTypeInt64);
|
||||
ret = MSOpSetAttrArray(res_mgr, op_add, "attr2", values, 2, MS_INT64);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
ret = MSOpGetAttrArrayInt64(res_mgr, op_add, "attr2", values, 2);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
|
|
|
@ -39,7 +39,7 @@ TEST_F(TestCApiGraph, test_multi_output_graph) {
|
|||
ASSERT_TRUE(res_mgr != nullptr);
|
||||
GraphHandle fg = MSFuncGraphCreate(res_mgr);
|
||||
ASSERT_TRUE(fg != nullptr);
|
||||
NodeHandle x = MSNewPlaceholder(res_mgr, fg, kNumberTypeInt32, NULL, 0);
|
||||
NodeHandle x = MSNewPlaceholder(res_mgr, fg, MS_INT32, NULL, 0);
|
||||
ASSERT_TRUE(x != nullptr);
|
||||
NodeHandle y = MSNewScalarConstantInt32(res_mgr, 2);
|
||||
ASSERT_TRUE(y != nullptr);
|
||||
|
|
|
@ -39,7 +39,7 @@ TEST_F(TestCApiNode, test_op_node) {
|
|||
ASSERT_TRUE(res_mgr != nullptr);
|
||||
GraphHandle fg = MSFuncGraphCreate(res_mgr);
|
||||
ASSERT_TRUE(fg != nullptr);
|
||||
NodeHandle x = MSNewPlaceholder(res_mgr, fg, kNumberTypeInt32, NULL, 0);
|
||||
NodeHandle x = MSNewPlaceholder(res_mgr, fg, MS_INT32, NULL, 0);
|
||||
ASSERT_TRUE(x != nullptr);
|
||||
NodeHandle y = MSNewScalarConstantInt32(res_mgr, 3);
|
||||
ASSERT_TRUE(y != nullptr);
|
||||
|
@ -83,9 +83,9 @@ TEST_F(TestCApiNode, test_normal_nodes) {
|
|||
// test Tensor Variable
|
||||
int64_t a_shape[] = {1, 2};
|
||||
float a_data[] = {1.2, 3.4};
|
||||
NodeHandle a1 = MSNewTensorVariable(res_mgr, fg, a_data, kNumberTypeFloat32, a_shape, 2, 2 * sizeof(float));
|
||||
NodeHandle a1 = MSNewTensorVariable(res_mgr, fg, a_data, MS_FLOAT32, a_shape, 2, 2 * sizeof(float));
|
||||
ASSERT_TRUE(a1 != nullptr);
|
||||
TensorHandle tensor1 = MSNewTensor(res_mgr, a_data, kNumberTypeFloat32, a_shape, 2, 2 * sizeof(float));
|
||||
TensorHandle tensor1 = MSNewTensor(res_mgr, a_data, MS_FLOAT32, a_shape, 2, 2 * sizeof(float));
|
||||
ASSERT_TRUE(tensor1 != nullptr);
|
||||
NodeHandle a2 = MSNewTensorVariableFromTensor(res_mgr, fg, tensor1);
|
||||
ASSERT_TRUE(a2 != nullptr);
|
||||
|
@ -105,9 +105,9 @@ TEST_F(TestCApiNode, test_normal_nodes) {
|
|||
// test Tensor Constant
|
||||
int64_t b_shape[] = {1, 2};
|
||||
int b_data[] = {4, 3};
|
||||
NodeHandle b1 = MSNewTensorConstant(res_mgr, b_data, kNumberTypeInt32, b_shape, 2, 2 * sizeof(int));
|
||||
NodeHandle b1 = MSNewTensorConstant(res_mgr, b_data, MS_INT32, b_shape, 2, 2 * sizeof(int));
|
||||
ASSERT_TRUE(b1 != nullptr);
|
||||
TensorHandle tensor2 = MSNewTensor(res_mgr, b_data, kNumberTypeInt32, b_shape, 2, 2 * sizeof(int));
|
||||
TensorHandle tensor2 = MSNewTensor(res_mgr, b_data, MS_INT32, b_shape, 2, 2 * sizeof(int));
|
||||
ASSERT_TRUE(tensor2 != nullptr);
|
||||
NodeHandle b2 = MSNewTensorConstantFromTensor(res_mgr, tensor2);
|
||||
ASSERT_TRUE(b2 != nullptr);
|
||||
|
@ -157,11 +157,11 @@ TEST_F(TestCApiNode, test_normal_nodes) {
|
|||
ASSERT_EQ(ret, RET_OK);
|
||||
ASSERT_EQ(vec_get[0], vec[0]);
|
||||
ASSERT_EQ(vec_get[1], vec[1]);
|
||||
NodeHandle x6 = MSNewTypeConstant(res_mgr, kNumberTypeInt32);
|
||||
NodeHandle x6 = MSNewTypeConstant(res_mgr, MS_INT32);
|
||||
ASSERT_TRUE(x6 != nullptr);
|
||||
TypeId value_6 = MSTypeConstantGetValue(res_mgr, x6, &ret);
|
||||
DataTypeC value_6 = MSTypeConstantGetValue(res_mgr, x6, &ret);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
ASSERT_EQ(value_6, kNumberTypeInt32);
|
||||
ASSERT_EQ(value_6, MS_INT32);
|
||||
MSResourceManagerDestroy(res_mgr);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,9 +38,9 @@ TEST_F(TestCApiTensor, test_new_tensor) {
|
|||
ASSERT_TRUE(res_mgr != nullptr);
|
||||
float data_value[] = {0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int64_t data_shape[] = {1, 1, 3, 3};
|
||||
TensorHandle tensor_false = MSNewTensor(res_mgr, NULL, kNumberTypeFloat32, data_shape, 4, 9 * sizeof(float));
|
||||
TensorHandle tensor_false = MSNewTensor(res_mgr, NULL, MS_FLOAT32, data_shape, 4, 9 * sizeof(float));
|
||||
ASSERT_TRUE(tensor_false == nullptr);
|
||||
TensorHandle tensor = MSNewTensor(res_mgr, data_value, kNumberTypeFloat32, data_shape, 4, 9 * sizeof(float));
|
||||
TensorHandle tensor = MSNewTensor(res_mgr, data_value, MS_FLOAT32, data_shape, 4, 9 * sizeof(float));
|
||||
ASSERT_TRUE(tensor != nullptr);
|
||||
size_t ele_num = MSTensorGetElementNum(res_mgr, tensor, &ret);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
|
@ -54,11 +54,11 @@ TEST_F(TestCApiTensor, test_new_tensor) {
|
|||
ASSERT_EQ(res[0], 0);
|
||||
ASSERT_EQ(res[4], 4);
|
||||
ASSERT_EQ(res[8], 8);
|
||||
ret = MSTensorSetDataType(res_mgr, tensor, kNumberTypeInt32);
|
||||
ret = MSTensorSetDataType(res_mgr, tensor, MS_INT32);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
TypeId type = MSTensorGetDataType(res_mgr, tensor, &ret);
|
||||
DataTypeC type = MSTensorGetDataType(res_mgr, tensor, &ret);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
ASSERT_EQ(type, kNumberTypeInt32);
|
||||
ASSERT_EQ(type, MS_INT32);
|
||||
int64_t new_shape[] = {2, 3, 4, 5};
|
||||
ret = MSTensorSetShape(res_mgr, tensor, new_shape, 4);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
|
@ -81,13 +81,13 @@ TEST_F(TestCApiTensor, test_new_tensor_with_src_type) {
|
|||
ASSERT_TRUE(res_mgr != nullptr);
|
||||
float data_value[] = {0, 1, 2, 3};
|
||||
int64_t data_shape[] = {1, 1, 2, 2};
|
||||
TensorHandle tensor_false = MSNewTensorWithSrcType(res_mgr, data_value, NULL, 4, kNumberTypeInt32, kNumberTypeFloat32);
|
||||
TensorHandle tensor_false = MSNewTensorWithSrcType(res_mgr, data_value, NULL, 4, MS_INT32, MS_FLOAT32);
|
||||
ASSERT_TRUE(tensor_false == nullptr);
|
||||
TensorHandle tensor = MSNewTensorWithSrcType(res_mgr, data_value, data_shape, 4, kNumberTypeInt32, kNumberTypeFloat32);
|
||||
TensorHandle tensor = MSNewTensorWithSrcType(res_mgr, data_value, data_shape, 4, MS_INT32, MS_FLOAT32);
|
||||
ASSERT_TRUE(tensor != nullptr);
|
||||
TypeId type = MSTensorGetDataType(res_mgr, tensor, &ret);
|
||||
DataTypeC type = MSTensorGetDataType(res_mgr, tensor, &ret);
|
||||
ASSERT_EQ(ret, RET_OK);
|
||||
ASSERT_EQ(type, kNumberTypeInt32);
|
||||
ASSERT_EQ(type, MS_INT32);
|
||||
MSResourceManagerDestroy(res_mgr);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue